In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset
import numpy as np
import h5py
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import gc

# Check if CUDA is available and set memory management
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set up CUDA memory management
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    # Enable memory optimization
    torch.backends.cudnn.benchmark = True

# Custom Dataset for Galaxy10
def load_galaxy10_data():
    with h5py.File("Galaxy10.h5", "r") as f:
        images = np.array(f["images"])  # Shape: (N, H, W, C)
        labels = np.array(f["ans"])  # Shape: (N,)
    return images, labels

class Galaxy10Dataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        
        # Convert to PIL Image (Ensure 3 channels)
        image = Image.fromarray(image[:, :, :3])
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

# Define Transformations - Further reduce image size to save memory
transform = transforms.Compose([
    transforms.Resize((96, 96)),  # Reduced from 112x112 to 96x96
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Load Dataset - with memory optimization
def get_dataloaders():
    images, labels = load_galaxy10_data()
    dataset = Galaxy10Dataset(images, labels, transform=transform)
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

    # Further reduce batch size
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, pin_memory=True)  # Reduced to 16
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, pin_memory=True)  # Reduced to 16
    
    return train_loader, test_loader

# Convolutional Token Embedding with reduced parameters
class ConvTokenEmbedding(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(ConvTokenEmbedding, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, 
                             stride=stride, padding=padding)
        self.norm = nn.LayerNorm(out_channels)
        self.activation = nn.GELU()
        self.out_channels = out_channels
        
    def forward(self, x):
        x = self.conv(x)  # [B, C, H, W]
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1)  # [B, H, W, C]
        x = self.norm(x)  # LayerNorm over channel dimension
        x = self.activation(x)
        x = x.permute(0, 3, 1, 2)  # [B, C, H, W]
        return x

# Memory-efficient Convolutional Projection
class ConvolutionalProjection(nn.Module):
    def __init__(self, dim, kernel_size=3, stride=1, padding=1):
        super(ConvolutionalProjection, self).__init__()
        self.conv = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=stride, padding=padding, groups=dim)
        
    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.conv(x)
        x = x.flatten(2).transpose(1, 2)
        return x

# Simplified Attention with lower memory footprint
class ConvAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super(ConvAttention, self).__init__()
        self.num_heads = num_heads
        assert dim % num_heads == 0, f"dim {dim} should be divisible by num_heads {num_heads}"
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        # Use separate projections to save memory during backward pass
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, H, W):
        B, N, C = x.shape
        
        # Project q, k, v separately to save memory
        q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = self.k(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = self.v(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        
        # Attention computation with memory efficiency
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        
        # Clean up intermediate tensors
        del q, k, v, attn
        return x

# Reduced MLP Block
class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super(MLP, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Further reduce hidden features
        hidden_features = int(hidden_features * 0.5)  # 50% reduction in hidden dim
        
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

# Optimized Transformer Block
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
        super(TransformerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = ConvAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio * 0.5), drop=drop)  # 50% reduction
        
    def forward(self, x, H, W):
        # Use clone to save memory in backward pass
        x_norm = self.norm1(x)
        x = x + self.attn(x_norm, H, W)
        x = x + self.mlp(self.norm2(x))
        return x

# Simplified CvT Stage
class CvTStage(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, depth, num_heads, mlp_ratio=4.):
        super(CvTStage, self).__init__()
        self.embedding = ConvTokenEmbedding(in_channels, out_channels, kernel_size, stride, padding)
        self.blocks = nn.ModuleList([
            TransformerBlock(out_channels, num_heads, mlp_ratio=mlp_ratio)
            for _ in range(depth)
        ])
        
    def forward(self, x):
        x = self.embedding(x)  # [B, C, H, W]
        H, W = x.shape[2], x.shape[3]
        x = x.flatten(2).permute(0, 2, 1)  # [B, H*W, C]
        
        for block in self.blocks:
            x = block(x, H, W)
            
        return x, H, W

# Much reduced CvT Model
class CvT(nn.Module):
    def __init__(self, num_classes=10, in_channels=3):
        super(CvT, self).__init__()
        # Stage 1 parameters - reduced dimensions
        s1_embed_dim = 48  # Reduced from 64
        s1_kernel_size = 7
        s1_stride = 4
        s1_padding = 3
        s1_depth = 1
        s1_num_heads = 1
        
        # Stage 2 parameters - reduced dimensions
        s2_embed_dim = 96  # Reduced from 192
        s2_kernel_size = 3
        s2_stride = 2
        s2_padding = 1
        s2_depth = 1  # Reduced to 1
        s2_num_heads = 3
        
        # Stage 3 parameters - reduced dimensions
        s3_embed_dim = 192  # Reduced from 384
        s3_kernel_size = 3
        s3_stride = 2
        s3_padding = 1
        s3_depth = 4  # Reduced from 6 to 4
        s3_num_heads = 6
        
        # Define stages
        self.stage1 = CvTStage(in_channels, s1_embed_dim, s1_kernel_size, s1_stride, s1_padding, s1_depth, s1_num_heads)
        self.stage2 = CvTStage(s1_embed_dim, s2_embed_dim, s2_kernel_size, s2_stride, s2_padding, s2_depth, s2_num_heads)
        self.stage3 = CvTStage(s2_embed_dim, s3_embed_dim, s3_kernel_size, s3_stride, s3_padding, s3_depth, s3_num_heads)
        
        # Classification head
        self.norm = nn.LayerNorm(s3_embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, s3_embed_dim))
        self.head = nn.Linear(s3_embed_dim, num_classes)
        
        # Initialize cls token
        nn.init.trunc_normal_(self.cls_token, std=0.02)
    
    def forward(self, x):
        B = x.shape[0]
        
        # Process through stages
        x, H1, W1 = self.stage1(x)
        x = x.permute(0, 2, 1).reshape(B, 48, H1, W1)  # [B, C, H, W]
        
        x, H2, W2 = self.stage2(x)
        x = x.permute(0, 2, 1).reshape(B, 96, H2, W2)  # [B, C, H, W]
        
        x, H3, W3 = self.stage3(x)  # [B, H*W, C]
        
        # Add classification token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        
        # Apply layer norm
        x = self.norm(x)
        
        # Use only the cls token for classification
        x = x[:, 0]
        
        # Classification head
        x = self.head(x)
        
        return x

def train_and_evaluate():
    # Get dataloaders
    train_loader, test_loader = get_dataloaders()
    
    # Increase gradient accumulation steps
    accumulation_steps = 4  # Increased from 2 to 4
    
    # Initialize Model, Loss, Optimizer
    model = CvT().to(device)
    
    # Enable mixed precision training if available
    scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)
    
    # Print model summary
    print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters")
    
    # Training Loop with Early Stopping and gradient accumulation
    num_epochs = 50
    early_stopping_patience = 5
    best_loss = float('inf')
    patience_counter = 0
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        # Empty cache before each epoch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        optimizer.zero_grad()  # Zero gradients at the beginning of each epoch
        
        for batch_idx, (images, labels) in enumerate(progress_bar):
            images, labels = images.to(device), labels.to(device)
            
            # Use mixed precision training if available
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    outputs = model(images)
                    loss = criterion(outputs, labels) / accumulation_steps
                
                # Scale gradients and call backward
                scaler.scale(loss).backward()
                
                # Update weights every accumulation_steps batches
                if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(train_loader):
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
            else:
                # Regular training for CPU
                outputs = model(images)
                loss = criterion(outputs, labels) / accumulation_steps
                loss.backward()
                
                if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(train_loader):
                    optimizer.step()
                    optimizer.zero_grad()
            
            running_loss += loss.item() * accumulation_steps
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            progress_bar.set_postfix(loss=loss.item() * accumulation_steps, accuracy=100 * correct / total)
            
            # Clear memory for this batch
            del images, labels, outputs, loss
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
            gc.collect()
        
        epoch_loss = running_loss / len(train_loader)
        accuracy = 100 * correct / total
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.2f}%")
        
        # Early Stopping Check
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            # Save model checkpoint
            torch.save(model.state_dict(), 'best_cvt_model.pth')
            print(f"Model saved at epoch {epoch+1}")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                print("Early stopping triggered.")
                break
    
    print("Training complete!")
    
    # Load best model for testing
    model.load_state_dict(torch.load('best_cvt_model.pth'))
    
    # Testing Loop
    model.eval()
    test_correct = 0
    test_total = 0
    all_preds = []
    all_labels = []
    all_probs = []
    
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images, labels = images.to(device), labels.to(device)
            
            # Use mixed precision for evaluation as well
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    outputs = model(images)
            else:
                outputs = model(images)
                
            probs = torch.softmax(outputs, dim=1)
            _, predicted = outputs.max(1)
            test_total += labels.size(0)
            test_correct += predicted.eq(labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
            
            # Clear memory for this batch
            del images, labels, outputs, probs
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
            
    test_accuracy = 100 * test_correct / test_total
    print(f"Test Accuracy: {test_accuracy:.2f}%")
    
    # Classification Report
    print("Classification Report:")
    print(classification_report(all_labels, all_preds, digits=3))

if __name__ == "__main__":
    train_and_evaluate()

Using device: cuda


  scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None


Model created with 1178170 parameters


  with torch.cuda.amp.autocast():
Epoch 1/50: 100%|██████████| 1090/1090 [02:36<00:00,  6.95it/s, accuracy=30.7, loss=1.35]


Epoch 1/50, Loss: 1.8238, Accuracy: 30.67%
Model saved at epoch 1


Epoch 2/50: 100%|██████████| 1090/1090 [02:23<00:00,  7.57it/s, accuracy=31.7, loss=1.75]


Epoch 2/50, Loss: 1.7475, Accuracy: 31.69%
Model saved at epoch 2


Epoch 3/50: 100%|██████████| 1090/1090 [02:25<00:00,  7.48it/s, accuracy=31.3, loss=1.8] 


Epoch 3/50, Loss: 1.7481, Accuracy: 31.32%


Epoch 4/50: 100%|██████████| 1090/1090 [02:27<00:00,  7.40it/s, accuracy=31.7, loss=1.62]


Epoch 4/50, Loss: 1.7471, Accuracy: 31.71%
Model saved at epoch 4


Epoch 5/50: 100%|██████████| 1090/1090 [02:21<00:00,  7.70it/s, accuracy=31.7, loss=1.23]


Epoch 5/50, Loss: 1.7469, Accuracy: 31.72%
Model saved at epoch 5


Epoch 6/50: 100%|██████████| 1090/1090 [03:12<00:00,  5.67it/s, accuracy=31.6, loss=2.25]


Epoch 6/50, Loss: 1.7476, Accuracy: 31.64%


Epoch 7/50: 100%|██████████| 1090/1090 [02:36<00:00,  6.95it/s, accuracy=31.9, loss=1.82]


Epoch 7/50, Loss: 1.7473, Accuracy: 31.86%


Epoch 8/50: 100%|██████████| 1090/1090 [02:19<00:00,  7.83it/s, accuracy=31.9, loss=1.64]


Epoch 8/50, Loss: 1.7473, Accuracy: 31.93%


Epoch 9/50: 100%|██████████| 1090/1090 [02:22<00:00,  7.64it/s, accuracy=31.7, loss=1.81]


Epoch 9/50, Loss: 1.7475, Accuracy: 31.65%


Epoch 10/50: 100%|██████████| 1090/1090 [02:23<00:00,  7.61it/s, accuracy=31.7, loss=1.7] 


Epoch 10/50, Loss: 1.7470, Accuracy: 31.68%
Early stopping triggered.
Training complete!


  with torch.cuda.amp.autocast():
Testing: 100%|██████████| 273/273 [00:03<00:00, 78.86it/s]

Test Accuracy: 32.22%
Classification Report:
              precision    recall  f1-score   support

           0      0.000     0.000     0.000       692
           1      0.322     1.000     0.487      1404
           2      0.000     0.000     0.000      1251
           3      0.000     0.000     0.000        76
           4      0.000     0.000     0.000       316
           5      0.000     0.000     0.000         5
           6      0.000     0.000     0.000       125
           7      0.000     0.000     0.000       218
           8      0.000     0.000     0.000       172
           9      0.000     0.000     0.000        98

    accuracy                          0.322      4357
   macro avg      0.032     0.100     0.049      4357
weighted avg      0.104     0.322     0.157      4357




  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
