In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np

In [2]:
class Model(nn.Module):
    def __init__(self,num_classes=200):
        super(Model,self).__init__()
      
        self.first_layer=nn.Conv2d(in_channels=3,out_channels=96,kernel_size=11,stride=4,padding=2)
        self.second_layer=nn.Conv2d(in_channels=96,out_channels=256,kernel_size=5,padding=2)
        self.third_layer=nn.Conv2d(in_channels=256,out_channels=384,kernel_size=3,padding=1)
        self.forth_layer=nn.Conv2d(in_channels=384,out_channels=384,kernel_size=3,padding=1)
        self.fifth_layer=nn.Conv2d(in_channels=384,out_channels=256,kernel_size=3,padding=1)
        self.lrn=nn.LocalResponseNorm(size=5,alpha=1e-4, beta=0.75)
        self.maxpool=nn.MaxPool2d(kernel_size=3, stride=2) #z=3,s=2
        
        self.fc_shape=None
        print("Fitting dummy data to get the shape of flattening layer")
        if self.fc_shape==None:
            dummy_x=torch.randn(1,3,224,224)
            self.get_flatten_size(dummy_x)
        
        self.fc1=nn.Linear(in_features=self.fc_shape,out_features=4096)
        self.fc2=nn.Linear(in_features=4096,out_features=4096)

        self.output_layer=nn.Linear(4096,num_classes)
        
        self._initialize_weights()
   
    
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.01)
                if m.bias is not None:
                    # Biases=1 for conv2/4/5 and FC hidden layers (fc1/fc2)
                    if m in [self.second_layer, self.forth_layer, self.fifth_layer, self.fc1, self.fc2]:
                        nn.init.constant_(m.bias, 1)
                    else:
                        nn.init.constant_(m.bias, 0)
        
    def get_flatten_size(self,x):
        x=F.relu(self.first_layer(x))
        x=self.lrn(x)
        x=self.maxpool(x)

        x=F.relu(self.second_layer(x))
        x=self.lrn(x)
        x=self.maxpool(x)

        x=F.relu(self.third_layer(x))
        x=F.relu(self.forth_layer(x))
        x=F.relu(self.fifth_layer(x))
        x=self.maxpool(x)


        self.fc_shape = x.view(x.size(0), -1).size(1)
        print(f"Ran over the dummy data to get the shape for first flattening layer as : {self.fc_shape}")
    def forward(self,x):
        x=F.relu(self.first_layer(x))
        x=self.lrn(x)
        x=self.maxpool(x)
        x=F.relu(self.second_layer(x))
        x=self.lrn(x)
        x=self.maxpool(x)
        x=F.relu(self.third_layer(x))
        x=F.relu(self.forth_layer(x))
        x=F.relu(self.fifth_layer(x))
        x=self.maxpool(x)
        x=x.view(-1,self.fc_shape)
        x=F.relu(self.fc1(x))
        X=F.dropout(x,0.5)
        x=F.relu(self.fc2(x))
        X=F.dropout(x,0.5)
        x=F.log_softmax(self.output_layer(x),dim=1)
        
        return x
        
        
        

In [3]:
model = Model()
print(model)

Fitting dummy data to get the shape of flattening layer
Ran over the dummy data to get the shape for first flattening layer as : 9216
Model(
  (first_layer): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
  (second_layer): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (third_layer): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (forth_layer): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fifth_layer): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (lrn): LocalResponseNorm(5, alpha=0.0001, beta=0.75, k=1.0)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=9216, out_features=4096, bias=True)
  (fc2): Linear(in_features=4096, out_features=4096, bias=True)
  (output_layer): Linear(in_features=4096, out_features=200, bias=True)
)


In [4]:
from datasets import load_dataset

ds = load_dataset("zh-plus/tiny-imagenet")

In [5]:
images = [item['image'] for item in ds['train']]


def compute_pca(dataset):
    pixels = []  
    for img in dataset:
        img = img.convert('RGB')  
        img_tensor = transforms.ToTensor()(img)  
        rearranged = img_tensor.permute(1, 2, 0)  
        flattened = rearranged.reshape(-1, 3).numpy()  
        pixels.append(flattened)
    all_pixels = np.vstack(pixels)
    cov = np.cov(all_pixels.T)  
    eigval, eigvec = np.linalg.eigh(cov)  
    return torch.from_numpy(eigval[::-1].copy()).float(), torch.from_numpy(eigvec[:, ::-1].copy()).float()  

eigval, eigvec = compute_pca(images)



class PCAColorAugmentation(object):
    def __init__(self, eigval, eigvec, alphastd=0.1):
        self.eigval = eigval
        self.eigvec = eigvec
        self.alphastd = alphastd

    def __call__(self, img_tensor):
        alpha = torch.normal(mean=0.0, std=self.alphastd, size=(3,))
        rgb_perturbation = torch.matmul(self.eigvec, alpha * self.eigval)  
        
        for i in range(3):
            img_tensor[i] += rgb_perturbation[i]
        
        return img_tensor.clamp(0, 1)  


class HuggingFaceToPyTorchDataset(Dataset):
    def __init__(self, hf_dataset, transform=None):
        self.hf_dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        item = self.hf_dataset[idx]
        image = item['image'].convert('RGB')  
        label = item['label']  
        if self.transform:
            image = self.transform(image)
        return image, label

    

train_transforms = transforms.Compose([
    transforms.Lambda(lambda img: img.convert('RGB')), 
    transforms.Resize(256),  
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),  
    PCAColorAugmentation(eigval, eigvec, alphastd=0.1),  
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])



test_transforms = transforms.Compose([
    transforms.Lambda(lambda img: img.convert('RGB')),  
    transforms.Resize(256),  
    transforms.TenCrop(224),  
    transforms.Lambda(lambda crops: torch.stack([
        transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])(crop) for crop in crops
    ]))  
])


train_dataset = HuggingFaceToPyTorchDataset(ds['train'], transform=train_transforms)

valid_dataset = HuggingFaceToPyTorchDataset(ds['valid'], transform=test_transforms)





In [6]:
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=128, shuffle=False, num_workers=4) 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model(num_classes=200).to(device)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  

# criterion = nn.CrossEntropyLoss()
criterion=nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)  




def evaluate(model, loader):
    model.eval()
    correct, total = 0, 0
    val_loss = 0.0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)  
            batch_size = images.size(0)
            images = images.view(batch_size * 10, 3, 224, 224)
            outputs = model(images)  
            outputs = outputs.view(batch_size, 10, -1)
            probs = F.softmax(outputs, dim=2)  
            avg_probs = probs.mean(dim=1)  
            loss = criterion(avg_probs, labels) 
            val_loss += loss.item() * batch_size  
            predicted = avg_probs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_loss /= total  
    val_acc = 100 * correct / total
    return val_loss, val_acc

epochs = 10
best_val_acc = 0.0
for epoch in range(1, epochs + 1):
    model.train()
    running_loss = 0.0
    for i, (imgs, labels) in enumerate(train_loader):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if (i + 1) % 50 == 0:
            print(f"Epoch {epoch:03d} | step {i+1:04d} | loss {running_loss/50:.4f}")
            running_loss = 0.0
    
    val_loss, val_acc = evaluate(model, valid_loader)
    print(f"Epoch {epoch:03d} | val-loss {val_loss:.4f} | val-acc {val_acc:.2f}%")
    
    scheduler.step(val_loss)
    
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_model.pth')
        print(f"Checkpoint saved with val-acc {val_acc:.2f}%")
    
    if optimizer.param_groups[0]['lr'] <= 1e-5:
        print("Learning-rate floor reached — finishing training early.")
        break

model.load_state_dict(torch.load('best_model.pth'))
final_val_loss, final_val_acc = evaluate(model, valid_loader)
print(f"Final best val-loss: {final_val_loss:.4f} | val-acc: {final_val_acc:.2f}%")

Fitting dummy data to get the shape of flattening layer
Ran over the dummy data to get the shape for first flattening layer as : 9216
Epoch 001 | step 0050 | loss 5.5930
Epoch 001 | step 0100 | loss 5.2998
Epoch 001 | step 0150 | loss 5.2983
Epoch 001 | step 0200 | loss 5.2984
Epoch 001 | step 0250 | loss 5.2986
Epoch 001 | step 0300 | loss 5.2986
Epoch 001 | step 0350 | loss 5.2987
Epoch 001 | step 0400 | loss 5.2986
Epoch 001 | step 0450 | loss 5.2987
Epoch 001 | step 0500 | loss 5.2986
Epoch 001 | step 0550 | loss 5.2990
Epoch 001 | step 0600 | loss 5.2988
Epoch 001 | step 0650 | loss 5.2987
Epoch 001 | step 0700 | loss 5.2988
Epoch 001 | step 0750 | loss 5.2989
Epoch 001 | val-loss -0.0050 | val-acc 0.50%
Checkpoint saved with val-acc 0.50%
Epoch 002 | step 0050 | loss 5.2983
Epoch 002 | step 0100 | loss 5.2983
Epoch 002 | step 0150 | loss 5.2985
Epoch 002 | step 0200 | loss 5.2986
Epoch 002 | step 0250 | loss 5.2986
Epoch 002 | step 0300 | loss 5.2986
Epoch 002 | step 0350 | loss 