In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset
import numpy as np
import h5py
from PIL import Image

# Pick the compute device: a CUDA GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Custom Dataset for Galaxy10
def load_galaxy10_data():
    with h5py.File("Galaxy10.h5", "r") as f:
        images = np.array(f["images"])  # Shape: (N, H, W, C)
        labels = np.array(f["ans"])  # Shape: (N,)
    return images, labels

class Galaxy10Dataset(Dataset):
    """Map-style dataset over in-memory Galaxy10 image and label arrays.

    Args:
        images: Array indexed by sample. Each item is expected to be an
            (H, W, C) image in a dtype ``PIL.Image.fromarray`` accepts
            (e.g. uint8). Grayscale (H, W) or (H, W, 1) items are replicated
            to three channels; channels beyond the third are dropped.
        labels: Array of integer class indices, one per image.
        transform: Optional callable applied to the PIL image.
    """

    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample *idx*; ``label`` is a Python int."""
        image = self.images[idx]
        # A Python int collates to an int64 tensor, which nn.CrossEntropyLoss
        # requires, regardless of the dtype the labels were stored in.
        label = int(self.labels[idx])

        # Promote single-channel images to 3 channels so the RGB slicing below
        # (and any 3-channel Normalize / model input) still applies.
        if image.ndim == 3 and image.shape[-1] == 1:
            image = image[..., 0]
        if image.ndim == 2:
            image = np.stack([image, image, image], axis=-1)

        # Convert to PIL Image, keeping only the first 3 channels.
        image = Image.fromarray(image[:, :, :3])

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# Define Transformations
# Random flips/rotations are training-time augmentation only. Evaluation must
# be deterministic, so the test split gets its own transform without them.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
transform = train_transform  # backward-compatible name for the training transform

# Load Dataset
images, labels = load_galaxy10_data()
dataset = Galaxy10Dataset(images, labels, transform=train_transform)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

# Split indices once (seeded, so the partition is reproducible), then expose
# the same arrays through two dataset views that differ only in transform.
# Previously random_split on one dataset left augmentation on the test samples.
split_generator = torch.Generator().manual_seed(42)
permutation = torch.randperm(len(dataset), generator=split_generator).tolist()
train_dataset = torch.utils.data.Subset(dataset, permutation[:train_size])
test_dataset = torch.utils.data.Subset(
    Galaxy10Dataset(images, labels, transform=test_transform),
    permutation[train_size:],
)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Define CvT Model
class CvT(nn.Module):
    """Small strided CNN classifier for square 3-channel images.

    Despite the name, this model has no attention or transformer layers. It
    consists of three stride-2 convolutions (3 -> 64 -> 128 -> 256 channels),
    each followed by ReLU, and then one linear layer over the flattened
    feature map.

    Args:
        num_classes: Number of output logits.
        image_size: Side length of the square input images. The default 224
            gives a 28x28 final feature map, i.e. the original
            ``256 * 28 * 28`` linear head.

    Raises:
        ValueError: If ``image_size`` is too small to survive the three
            strided convolutions.
    """

    def __init__(self, num_classes=10, image_size=224):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.relu = nn.ReLU()

        # Derive the flattened feature size from the conv geometry rather than
        # hard-coding it, so inputs other than 224x224 also work.
        side = image_size
        for conv in (self.conv1, self.conv2, self.conv3):
            side = self._conv_output_size(side, conv)
        if side < 1:
            raise ValueError(f"image_size={image_size} is too small for CvT")
        self.fc = nn.Linear(self.conv3.out_channels * side * side, num_classes)

    @staticmethod
    def _conv_output_size(size, conv):
        """Return the output side length of square-kernel *conv* for input side *size*."""
        kernel = conv.kernel_size[0]
        stride = conv.stride[0]
        padding = conv.padding[0]
        dilation = conv.dilation[0]
        return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

    def forward(self, x):
        """Map a (batch, 3, image_size, image_size) tensor to (batch, num_classes) logits."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))
        x = torch.flatten(x, 1)
        return self.fc(x)

# Build the classifier on the chosen device, then its loss and optimiser.
model = CvT()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
# Adam with a small learning rate and a light weight-decay penalty.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4, weight_decay=1e-5)

# Training Loop
num_epochs = 30
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum of per-sample losses over the epoch
    correct = 0
    total = 0
    # Distinct loop names so the module-level `images` / `labels` arrays
    # loaded earlier are not overwritten by the last mini-batch.
    for batch_images, batch_labels in train_loader:
        batch_images = batch_images.to(device)
        # CrossEntropyLoss expects int64 class indices as targets.
        batch_labels = batch_labels.to(device, dtype=torch.long)
        optimizer.zero_grad()
        outputs = model(batch_images)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()
        batch_size = batch_labels.size(0)
        # With default reduction the loss is a batch mean; weighting by batch
        # size keeps a smaller final batch from skewing the epoch average.
        running_loss += loss.item() * batch_size
        predicted = outputs.argmax(dim=1)
        total += batch_size
        correct += (predicted == batch_labels).sum().item()
    accuracy = 100 * correct / total
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/total:.4f}, Accuracy: {accuracy:.2f}%")

print("Training complete!")

# Testing Loop
model.eval()
test_correct = 0
test_total = 0
with torch.no_grad():
    # Distinct loop names so the module-level `images` / `labels` arrays
    # loaded earlier are not overwritten by the last test batch.
    for batch_images, batch_labels in test_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device, dtype=torch.long)
        outputs = model(batch_images)
        predicted = outputs.argmax(dim=1)
        test_total += batch_labels.size(0)
        test_correct += (predicted == batch_labels).sum().item()

test_accuracy = 100 * test_correct / test_total
print(f"Test Accuracy: {test_accuracy:.2f}%")


Epoch 1/30, Loss: 1.4385, Accuracy: 42.54%
Epoch 2/30, Loss: 1.2111, Accuracy: 51.11%
Epoch 3/30, Loss: 1.1365, Accuracy: 53.71%
Epoch 4/30, Loss: 1.0897, Accuracy: 55.34%
Epoch 5/30, Loss: 1.0422, Accuracy: 57.28%
Epoch 6/30, Loss: 0.9975, Accuracy: 59.53%
Epoch 7/30, Loss: 0.9779, Accuracy: 60.67%
Epoch 8/30, Loss: 0.9399, Accuracy: 62.35%
Epoch 9/30, Loss: 0.9123, Accuracy: 63.89%
Epoch 10/30, Loss: 0.8764, Accuracy: 65.34%
Epoch 11/30, Loss: 0.8459, Accuracy: 67.17%
Epoch 12/30, Loss: 0.8058, Accuracy: 69.46%
Epoch 13/30, Loss: 0.7778, Accuracy: 70.51%
Epoch 14/30, Loss: 0.7380, Accuracy: 72.76%
Epoch 15/30, Loss: 0.7153, Accuracy: 73.24%
Epoch 16/30, Loss: 0.6848, Accuracy: 74.62%
Epoch 17/30, Loss: 0.6603, Accuracy: 75.80%
Epoch 18/30, Loss: 0.6356, Accuracy: 76.76%
Epoch 19/30, Loss: 0.6138, Accuracy: 77.35%
Epoch 20/30, Loss: 0.5991, Accuracy: 77.86%
Epoch 21/30, Loss: 0.5796, Accuracy: 78.74%
Epoch 22/30, Loss: 0.5734, Accuracy: 79.10%
Epoch 23/30, Loss: 0.5520, Accuracy: 80.0