## Import necessary libraries

In [1]:
pip install wandb

Note: you may need to restart the kernel to use updated packages.


In [2]:
### Necessary Imports and dependencies
import os
import shutil
import time
import math
from enum import Enum
from functools import partial
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torchvision.transforms import v2
import torchvision.transforms as transforms
from typing import Any, Dict, Union, Type, Callable, Optional, List
from torchvision.models.vision_transformer import MLPBlock
import wandb
import json
from PIL import Image
from torch.utils.data import Dataset
import random
import torchvision.transforms.functional as TF
import numpy as np
import torchvision.transforms as transforms
from PIL import ImageOps, ImageEnhance
import matplotlib.pyplot as plt

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
device

device(type='cuda', index=0)

## Epochs and Batch_size

In [4]:
num_epochs = 200
batch_size = 16

## Data preprocessing and Loading

In [5]:
class RandAugment:
    def __init__(self, n=9, m=0.5):
        self.n = n
        self.m = m  # [0, 30] in paper, but we use [0, 1] for simplicity
        self.augment_list = [
            self.auto_contrast, self.equalize, self.rotate, self.solarize, 
            self.color, self.contrast, self.brightness, self.sharpness,
            self.shear_x, self.shear_y, self.translate_x, self.translate_y,
            self.posterize, self.solarize_add, self.invert, self.identity
        ]

    def __call__(self, img):
        ops = random.choices(self.augment_list, k=self.n)
        for op in ops:
            img = op(img)
        return img

    def auto_contrast(self, img):
        return ImageOps.autocontrast(img)

    def equalize(self, img):
        return ImageOps.equalize(img)

    def rotate(self, img):
        return TF.rotate(img, self.m * 30)

    def solarize(self, img):
        return TF.solarize(img, int((1 - self.m) * 255))

    def color(self, img):
        return TF.adjust_saturation(img, 1 + self.m)

    def contrast(self, img):
        return TF.adjust_contrast(img, 1 + self.m)

    def brightness(self, img):
        return TF.adjust_brightness(img, 1 + self.m)

    def sharpness(self, img):
        return ImageEnhance.Sharpness(img).enhance(1 + self.m)

    def shear_x(self, img):
        return TF.affine(img, 0, [0, 0], 1, [self.m, 0])

    def shear_y(self, img):
        return TF.affine(img, 0, [0, 0], 1, [0, self.m])

    def translate_x(self, img):
        return TF.affine(img, 0, [int(self.m * img.size[0] / 3), 0], 1, [0, 0])

    def translate_y(self, img):
        return TF.affine(img, 0, [0, int(self.m * img.size[1] / 3)], 1, [0, 0])

    def posterize(self, img):
        return TF.posterize(img, int((1 - self.m) * 8))

    def solarize_add(self, img):
        return TF.solarize(TF.adjust_brightness(img, 1 + self.m), int((1 - self.m) * 255))

    def invert(self, img):
        return TF.invert(img) if random.random() < 0.5 else img

    def identity(self, img):
        return img

class Mixup(nn.Module):
    def __init__(self, alpha=0.8):
        super().__init__()
        self.alpha = alpha

    def forward(self, batch):
        images, labels = batch
        lam = np.random.beta(self.alpha, self.alpha)
        batch_size = images.size(0)
        index = torch.randperm(batch_size)
        mixed_images = lam * images + (1 - lam) * images[index, :]
        labels_a, labels_b = labels, labels[index]
        return mixed_images, labels_a, labels_b, lam

class CutMix(nn.Module):
    def __init__(self, alpha=1.0):
        super().__init__()
        self.alpha = alpha

    def forward(self, batch):
        images, labels = batch
        lam = np.random.beta(self.alpha, self.alpha)
        batch_size, _, H, W = images.shape
        cx = np.random.uniform(0, W)
        cy = np.random.uniform(0, H)
        w = W * np.sqrt(1 - lam)
        h = H * np.sqrt(1 - lam)
        x0 = int(np.clip(cx - w // 2, 0, W))
        y0 = int(np.clip(cy - h // 2, 0, H))
        x1 = int(np.clip(cx + w // 2, 0, W))
        y1 = int(np.clip(cy + h // 2, 0, H))
        index = torch.randperm(batch_size)
        images[:, :, y0:y1, x0:x1] = images[index, :, y0:y1, x0:x1]
        lam = 1 - ((x1 - x0) * (y1 - y0) / (W * H))
        labels_a, labels_b = labels, labels[index]
        return images, labels_a, labels_b, lam

class RandomErasing(nn.Module):
    def __init__(self, probability=0.25, sl=0.02, sh=0.4, r1=0.3, r2=1/0.3):
        super().__init__()
        self.probability = probability
        self.sl = sl
        self.sh = sh
        self.r1 = r1
        self.r2 = r2

    def forward(self, img):
        if random.uniform(0, 1) > self.probability:
            return img
        
        for attempt in range(100):
            area = img.size()[1] * img.size()[2]
            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, self.r2)

            h = int(round(np.sqrt(target_area * aspect_ratio)))
            w = int(round(np.sqrt(target_area / aspect_ratio)))

            if w < img.size()[2] and h < img.size()[1]:
                x1 = random.randint(0, img.size()[1] - h)
                y1 = random.randint(0, img.size()[2] - w)
                if img.size()[0] == 3:
                    img[0, x1:x1+h, y1:y1+w] = random.uniform(0, 1)
                    img[1, x1:x1+h, y1:y1+w] = random.uniform(0, 1)
                    img[2, x1:x1+h, y1:y1+w] = random.uniform(0, 1)
                else:
                    img[0, x1:x1+h, y1:y1+w] = random.uniform(0, 1)
                return img
        return img

class LabelSmoothing(nn.Module):
    def __init__(self, smoothing=0.1):
        super(LabelSmoothing, self).__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        n_classes = pred.size(1)
        one_hot = torch.zeros_like(pred).scatter(1, target.unsqueeze(1), 1)
        smooth_one_hot = one_hot * (1 - self.smoothing) + self.smoothing / n_classes
        log_prob = nn.functional.log_softmax(pred, dim=1)
        return torch.mean(torch.sum(-smooth_one_hot * log_prob, dim=1))

# Updated ImageNet100Dataset
class ImageNet100Dataset(torch.utils.data.Dataset):
    def __init__(self, root_dirs, labels_file, transform=None, augment=None, retain=False, forget=False, forget_label="n01818515"):
        self.transform = transform
        self.augment = augment
        self.images = []
        self.labels = []
        self.label_to_idx = {}
        
        with open(labels_file, 'r') as f:
            label_dict = json.load(f)
        
        unique_labels = sorted(label_dict.keys())
        self.label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
        
        for root_dir in root_dirs:
            for label in os.listdir(root_dir):
                if retain:
                    if label == forget_label:
                        continue
                    else:
                        label_path = os.path.join(root_dir, label)
                        if os.path.isdir(label_path):
                            for img_name in os.listdir(label_path):
                                img_path = os.path.join(label_path, img_name)
                                self.images.append(img_path)
                                self.labels.append(self.label_to_idx[label])
                elif forget:
                    if label != forget_label:
                        continue
                    else:
                        label_path = os.path.join(root_dir, label)
                        if os.path.isdir(label_path):
                            for img_name in os.listdir(label_path):
                                img_path = os.path.join(label_path, img_name)
                                self.images.append(img_path)
                                self.labels.append(self.label_to_idx[label])
                else:
                    label_path = os.path.join(root_dir, label)
                    if os.path.isdir(label_path):
                        for img_name in os.listdir(label_path):
                            img_path = os.path.join(label_path, img_name)
                            self.images.append(img_path)
                            self.labels.append(self.label_to_idx[label])
    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)
        
        if self.augment:
            image = self.augment(image)
        
        label = torch.tensor(label)
        
        return image, label

# Define transformations
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.05, 1.0)),
    transforms.RandomHorizontalFlip(),
    RandAugment(n=9, m=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    RandomErasing(probability=0.25)
])

val_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Create the datasets
train_dirs = [
    '/kaggle/input/imagenet100/train.X1',
    '/kaggle/input/imagenet100/train.X2',
    '/kaggle/input/imagenet100/train.X3',
    '/kaggle/input/imagenet100/train.X4'
]
val_dir = ['/kaggle/input/imagenet100/val.X']
labels_file = '/kaggle/input/imagenet100/Labels.json'

train_dataset = ImageNet100Dataset(
    root_dirs=train_dirs,
    labels_file=labels_file,
    transform=train_transform,
    retain = True
)

val_dataset = ImageNet100Dataset(
    root_dirs=val_dir,
    labels_file=labels_file,
    transform=val_transform
)

# Custom collate function for Mixup and CutMix
def collate_fn(batch):
    images, labels = torch.utils.data.default_collate(batch)
    if random.random() < 0.5:
        return Mixup(alpha=0.8)((images, labels))
    else:
        return CutMix(alpha=1.0)((images, labels))

# Create data loaders
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=True
)

val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True
)

## Steps

In [6]:
n = len(train_dataset)
total_steps = round((n * num_epochs) / batch_size)
warmup_try = 10000

## Model

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)  # Adjusted for larger input
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # Add maxpool to handle larger image
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # Use adaptive pooling for flexibility with input size
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.maxpool(out)  # Add max pooling after the first convolution
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

def ResNet50(num_classes=100):  # Typically, ResNet50 is used for 1000 classes (e.g., ImageNet)
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)

In [8]:
def weight_decay_param(n, p):
    return p.ndim >= 2 and n.endswith('weight')

# create model
model = ResNet50()
model = nn.DataParallel(model)
model.to('cuda')

wd_params = [p for n, p in model.named_parameters() if weight_decay_param(n, p) and p.requires_grad]
non_wd_params = [p for n, p in model.named_parameters() if not weight_decay_param(n, p) and p.requires_grad]
original_model = model

weight_decay = 0.05
learning_rate = 1e-3

# Label smoothing loss
criterion = LabelSmoothing(smoothing=0.1)

optimizer = torch.optim.AdamW(
    [
        {"params": wd_params, "weight_decay": weight_decay},
        {"params": non_wd_params, "weight_decay": weight_decay},
    ],
    lr=learning_rate,
    betas=(0.9, 0.999)  # Set beta1=0.9 and beta2=0.999
)

warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: step / warmup_try)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps - warmup_try)
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, [warmup, cosine], [warmup_try])

In [9]:
#Change_path_for_the_directory;This is the directory where model weights are to be saved
checkpoint_path = "/kaggle/working/"

def save_checkpoint(state, is_best, path, filename='imagenet_baseline_patchconvcheckpoint.pth.tar'):
    filename = os.path.join(path, filename)
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, os.path.join(path, 'model_best.pth.tar'))

def save_checkpoint_step(step, model, best_acc1, optimizer, scheduler, checkpoint_path):
    # Define the filename with the current step
    filename = os.path.join(checkpoint_path, f'BaseLine_with_PE.pt')
    
    # Save the checkpoint
    torch.save({
        'step': step,
        'state_dict': model.state_dict(),
        'best_acc1': best_acc1,
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict()
    }, filename)
    

class Summary(Enum):
    NONE = 0
    AVERAGE = 1
    SUM = 2
    COUNT = 3

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()
        
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def all_reduce(self):
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        self.sum, self.count = total.tolist()
        self.avg = self.sum / self.count
    
    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
    
    def summary(self):
        fmtstr = ''
        if self.summary_type is Summary.NONE:
            fmtstr = ''
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = '{name} {avg:.3f}'
        elif self.summary_type is Summary.SUM:
            fmtstr = '{name} {sum:.3f}'
        elif self.summary_type is Summary.COUNT:
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)
        
        return fmtstr.format(**self.__dict__)
class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))
        
    def display_summary(self):
        entries = [" *"]
        entries += [meter.summary() for meter in self.meters]
        print(' '.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

In [10]:
log_steps = 855000

wandb.login(key="25ce702053770419f4ace9b1e647163142718b11")

# Initialize a new run
wandb.init(project="aaai-lora", name = "Resnet")

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33matharvmittal1[0m ([33matharvmittal1-iit-roorkee[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Tracking run with wandb version 0.18.3
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20241119_100804-mij143zp[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mResnet[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/atharvmittal1-iit-roorkee/aaai-lora[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/atharvmittal1-iit-roorkee/aaai-lora/runs/mij143zp[0m


In [11]:
start_step = 0
best_acc1 = 0

In [12]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

def validate(val_loader, model, criterion, step, use_wandb=False, print_freq=100):
    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if torch.cuda.is_available():
                images = images.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)
            elif torch.backends.mps.is_available():
                images = images.to('mps')
                target = target.to('mps')

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                progress.display(i)

    progress.display_summary()
    
    if use_wandb:        
        log_data = {
            'val/loss': losses.avg,
            'val/acc@1': top1.avg,
            'val/acc@5': top5.avg,
        }
        wandb.log(log_data, step=step)

    return top1.avg

def train(train_loader, val_loader, start_step, total_steps, original_model, model, criterion, optimizer, scheduler, device):
    
    def load_checkpoint(checkpoint_path, model, optimizer, scheduler):
        print(f"Loading checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path)
        start_step = checkpoint['step']
        model.load_state_dict(checkpoint['state_dict'])
        best_acc1 = checkpoint['best_acc1']
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        print(f"Loaded checkpoint. Resuming from step {start_step}")
        return start_step, best_acc1

    
    # Load checkpoint
    start_step, best_acc1 = load_checkpoint("/kaggle/input/750000/pytorch/default/1/BaseLine_with_PE_Res.pt", original_model, optimizer, scheduler)
    
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    print_freq = 100
    log_steps = 855000
    
    progress = ProgressMeter(
        total_steps,
        [batch_time, data_time, losses, top1, top5]
    )

    model.train()
    end = time.time()
    
    def infinite_loader():
        while True:
            yield from train_loader
            
    for step, (images, labels_a, labels_b, lam) in zip(range(start_step + 1, total_steps + 1), infinite_loader()):
        
        
        data_time.update(time.time() - end)

        images = images.to(device, non_blocking=True)
        labels_a = labels_a.to(device, non_blocking=True)
        labels_b = labels_b.to(device, non_blocking=True)
        
        # Convert lam to a tensor if it's not already one
        if not isinstance(lam, torch.Tensor):
            lam = torch.tensor(lam, device=device)
        else:
            lam = lam.to(device, non_blocking=True)

        output = model(images)
        loss = lam * criterion(output, labels_a) + (1 - lam) * criterion(output, labels_b)

        # Compute accuracy (this is an approximation for mixed labels)
        acc1_a, acc5_a = accuracy(output, labels_a, topk=(1, 5))
        acc1_b, acc5_b = accuracy(output, labels_b, topk=(1, 5))
        acc1 = lam * acc1_a + (1 - lam) * acc1_b
        acc5 = lam * acc5_a + (1 - lam) * acc5_b

        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0].item(), images.size(0))
        top5.update(acc5[0].item(), images.size(0))

        loss.backward()
        l2_grads = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()

        batch_time.update(time.time() - end)
        end = time.time()
        
        if step % print_freq == 0:
            progress.display(step)
            if wandb:
                with torch.no_grad():
                    l2_params = sum(p.square().sum().item() for _, p in model.named_parameters())
                    
                samples_per_second_per_gpu = images.size(0) / batch_time.val
                samples_per_second = samples_per_second_per_gpu 
                log_data = {
                    "train/loss": losses.val,
                    'train/acc@1': top1.val,
                    'train/acc@5': top5.val,
                    "data_time": data_time.val,
                    "batch_time": batch_time.val,
                    "samples_per_second": samples_per_second,
                    "samples_per_second_per_gpu": samples_per_second_per_gpu,
                    "lr": scheduler.get_last_lr()[0],
                    "l2_grads": l2_grads.item(),
                    "l2_params": math.sqrt(l2_params)
                }
                wandb.log(log_data, step=step)
        
        if ((step % print_freq == 0) and ((step % log_steps != 0) and (step != total_steps))):        
            save_checkpoint_step(step, model, best_acc1, optimizer, scheduler, checkpoint_path)
                
        if step % log_steps == 0:
            acc1 = validate(val_loader, original_model, criterion, step)
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)
            
            save_checkpoint({
                'step': step,
                'state_dict': original_model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
                'scheduler' : scheduler.state_dict()
            }, is_best, checkpoint_path)

            break
            
        elif step == total_steps:
            acc1 = validate(val_loader, original_model, criterion, step)
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)
            
            save_checkpoint({
                'step': step,
                'state_dict': original_model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
                'scheduler' : scheduler.state_dict()
            }, is_best, checkpoint_path)

        scheduler.step()
        torch.cuda.empty_cache()

# Use the modified train function
train(train_loader, val_loader, start_step, total_steps, original_model, model, criterion, optimizer, scheduler, device)

Loading checkpoint from /kaggle/input/750000/pytorch/default/1/BaseLine_with_PE_Res.pt


  checkpoint = torch.load(checkpoint_path)


Loaded checkpoint. Resuming from step 749900


  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):


[ 750000/1608750]	Time  0.464 ( 0.519)	Data  0.368 ( 0.399)	Loss 3.9149e+00 (4.0026e+00)	Acc@1   5.69 ( 14.28)	Acc@5  35.28 ( 34.38)
[ 750100/1608750]	Time  0.519 ( 0.512)	Data  0.416 ( 0.402)	Loss 3.6971e+00 (4.0089e+00)	Acc@1   6.01 ( 14.15)	Acc@5  48.05 ( 34.18)
[ 750200/1608750]	Time  0.478 ( 0.514)	Data  0.385 ( 0.408)	Loss 3.7940e+00 (4.0396e+00)	Acc@1  11.30 ( 13.31)	Acc@5  42.83 ( 32.75)
[ 750300/1608750]	Time  0.549 ( 0.517)	Data  0.439 ( 0.411)	Loss 4.1902e+00 (4.0472e+00)	Acc@1  15.00 ( 13.32)	Acc@5  27.50 ( 32.68)
[ 750400/1608750]	Time  0.509 ( 0.521)	Data  0.398 ( 0.414)	Loss 3.9136e+00 (4.0458e+00)	Acc@1  14.12 ( 13.54)	Acc@5  32.87 ( 32.67)
[ 750500/1608750]	Time  0.470 ( 0.524)	Data  0.369 ( 0.418)	Loss 3.6907e+00 (4.0402e+00)	Acc@1  18.24 ( 13.66)	Acc@5  36.48 ( 33.09)
[ 750600/1608750]	Time  0.531 ( 0.526)	Data  0.420 ( 0.420)	Loss 4.3894e+00 (4.0410e+00)	Acc@1   9.67 ( 13.63)	Acc@5  18.75 ( 32.97)
[ 750700/1608750]	Time  0.568 ( 0.528)	Data  0.463 ( 0.421)	Loss 4.33

In [13]:
wandb.finish()

[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:                 batch_time ▅█▅▄▄▃▂▃▃▂▂▂▂▃▂▃▄▂▄▂▂▂▂▂▂▂▂▂▁▂▂▂▁▂▃▂▂▁▁▂
[34m[1mwandb[0m:                  data_time █▇██▃▄▂▃▃▃▂▄▃▃▄▄▃▅▄▃▂▂▂▁▄▂▃▂▁▂▃▃▃▃▃▂▁▁▃▁
[34m[1mwandb[0m:                   l2_grads ▃▅▅▄▁▂▃▁▄▅▅▆▅▁▄▆▄▄▇▃▄▅▂▅▅▅▄▃▇▆▅▅█▇▆▅▇▆▄▇
[34m[1mwandb[0m:                  l2_params ███▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁
[34m[1mwandb[0m:                         lr █████▇▇▆▆▆▆▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▄▄▃▃▃▃▃▂▂▂▁▁▁▁
[34m[1mwandb[0m:         samples_per_second ▃▁▃▅▄▆▅▆▆▃▆▇▅▆▆▅▇▇█▇▇▆▇▇▆▇█▇▅█▇▅▅▇▇▇▅▅▇▇
[34m[1mwandb[0m: samples_per_second_per_gpu ▂▁▁▃▅▅▄▆▄▇▅▆▆▇▄▄▅█▆▆▆▆▆▅▇▇▂▇▇▆▅▇▇█▇▇▇▅▇▇
[34m[1mwandb[0m:                train/acc@1 ▄▂▆▂▆▁▂▃▄▁▅▇▆▅█▇▃▁▅▂▄▄▄▆▆▄▅▃▅▅▅▄▄▅▆▂█▃▅▄
[34m[1mwandb[0m:                train/acc@5 ▂▃▂▃▃▄█▁▃▂▃▆▃█▆▂▂▅▆▆▂▃▄▄▃▇▄▂▆▄▄▂▄▅▂▃▅▂▃▄
[34m[1mwandb[0m:                 train/loss ▃▇▃▄▆▂▃

In [14]:
# # Test dataset paths first
# print("=== Testing Dataset Paths ===")
# for path in train_dirs + val_dir:
#     print(f"Checking path {path}: {os.path.exists(path)}")
# print(f"Checking labels file {labels_file}: {os.path.exists(labels_file)}")

# # Test label loading
# print("\n=== Testing Label Loading ===")
# with open(labels_file, 'r') as f:
#     label_dict = json.load(f)
# print("Number of unique labels:", len(label_dict))
# print("First few labels:", list(label_dict.keys())[:5])

# # Test dataset creation
# print("\n=== Testing Dataset Creation ===")
# try:
#     test_dataset = ImageNet100Dataset(
#         root_dirs=train_dirs,
#         labels_file=labels_file,
#         transform=train_transform
#     )
#     print("Training dataset size:", len(test_dataset))
    
#     # Test single item loading
#     img, label = test_dataset[0]
#     print("Single image shape:", img.shape)
#     print("Single label:", label)
    
#     # Visualize transformations on one image
#     plt.figure(figsize=(15, 3))
    
#     # Original image
#     img_orig = Image.open(test_dataset.images[0]).convert('RGB')
#     plt.subplot(1, 4, 1)
#     plt.imshow(img_orig)
#     plt.title("Original")
    
#     # Apply transformations 3 times to show randomness
#     for i in range(3):
#         img_transformed, _ = test_dataset[0]
#         plt.subplot(1, 4, i+2)
#         plt.imshow(img_transformed.permute(1, 2, 0).clip(0, 1))
#         plt.title(f"Transform {i+1}")
    
#     plt.show()
    
# except Exception as e:
#     print("Error in dataset creation:", str(e))

# # Test dataloader
# print("\n=== Testing DataLoader ===")
# try:
#     test_loader = torch.utils.data.DataLoader(
#         test_dataset,
#         batch_size=batch_size,
#         shuffle=True,
#         collate_fn=collate_fn,
#         num_workers=2
#     )
    
#     # Get one batch
#     batch = next(iter(test_loader))
#     print("Batch length:", len(batch))  # Should be 4 (images, labels_a, labels_b, lam)
#     images, labels_a, labels_b, lam = batch
#     print("Batch shapes:")
#     print(f"Images: {images.shape}")
#     print(f"Labels A: {labels_a.shape}")
#     print(f"Labels B: {labels_b.shape}")
#     print(f"Lambda: {lam}")
    
# except Exception as e:
#     print("Error in dataloader:", str(e))
    
# print("=== Testing Model ===")

# # Test model creation
# try:
#     test_model = SimpleVisionTransformer(
#         image_size=256,
#         patch_size=16,
#         num_layers=12,
#         num_heads=6,
#         hidden_dim=384,
#         mlp_dim=1536,
#     )
#     print("Model created successfully")
    
#     # Print model summary
#     print("\nModel Architecture:")
#     print(test_model)
    
#     # Count parameters
#     total_params = sum(p.numel() for p in test_model.parameters())
#     trainable_params = sum(p.numel() for p in test_model.parameters() if p.requires_grad)
#     print(f"\nTotal parameters: {total_params:,}")
#     print(f"Trainable parameters: {trainable_params:,}")
    
#     # Test forward pass
#     print("\nTesting forward pass...")
#     test_model.to(device)
#     test_input = torch.randn(2, 3, 256, 256).to(device)  # batch size of 2
#     with torch.no_grad():
#         output = test_model(test_input)
#     print("Output shape:", output.shape)
    
#     # Test memory usage
#     if torch.cuda.is_available():
#         print("\nGPU Memory Usage:")
#         print(f"Allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB")
#         print(f"Cached: {torch.cuda.memory_reserved()/1024**2:.2f} MB")
    
# except Exception as e:
#     print("Error in model testing:", str(e))
    
# print("=== Testing Loss Function and Optimizer ===")

# try:
#     # Test loss function
#     print("\nTesting Label Smoothing Loss...")
#     test_criterion = LabelSmoothing(smoothing=0.1)
#     test_predictions = torch.randn(4, 100)  # 4 samples, 100 classes
#     test_targets = torch.tensor([0, 1, 2, 3])
#     test_loss = test_criterion(test_predictions, test_targets)
#     print("Test loss value:", test_loss.item())
    
#     # Test with different smoothing values
#     for smoothing in [0.0, 0.1, 0.2]:
#         criterion = LabelSmoothing(smoothing=smoothing)
#         loss = criterion(test_predictions, test_targets)
#         print(f"Loss with smoothing {smoothing}: {loss.item()}")
    
#     # Test optimizer
#     print("\nTesting Optimizer...")
#     test_model = SimpleVisionTransformer(
#         image_size=256,
#         patch_size=16,
#         num_layers=12,
#         num_heads=6,
#         hidden_dim=384,
#         mlp_dim=1536,
#     ).to(device)
    
#     # Test parameter grouping
#     wd_params = [p for n, p in test_model.named_parameters() if weight_decay_param(n, p) and p.requires_grad]
#     non_wd_params = [p for n, p in test_model.named_parameters() if not weight_decay_param(n, p) and p.requires_grad]
#     print(f"Parameters with weight decay: {len(wd_params)}")
#     print(f"Parameters without weight decay: {len(non_wd_params)}")
    
#     # Test optimizer creation
#     test_optimizer = torch.optim.AdamW(
#         [
#             {"params": wd_params, "weight_decay": weight_decay},
#             {"params": non_wd_params, "weight_decay": 0.0},
#         ],
#         lr=learning_rate,
#     )
#     print("Optimizer created successfully")
    
#     # Test scheduler
#     print("\nTesting Learning Rate Scheduler...")
#     test_warmup = torch.optim.lr_scheduler.LambdaLR(test_optimizer, lr_lambda=lambda step: step / warmup_try)
#     test_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(test_optimizer, T_max=total_steps - warmup_try)
#     test_scheduler = torch.optim.lr_scheduler.SequentialLR(test_optimizer, [test_warmup, test_cosine], [warmup_try])
    
#     # Print learning rates at different steps
#     steps_to_check = [0, warmup_try//2, warmup_try, total_steps//2, total_steps]
#     print("\nLearning rate at different steps:")
#     for step in steps_to_check:
#         for _ in range(step):
#             test_scheduler.step()
#         print(f"Step {step}: {test_scheduler.get_last_lr()[0]}")
    
# except Exception as e:
#     print("Error in loss/optimizer testing:", str(e))
    
# print("=== Testing Training Setup ===")

# # Test CUDA availability
# print("\nTesting CUDA Setup:")
# print(f"CUDA available: {torch.cuda.is_available()}")
# print(f"Current device: {device}")
# if torch.cuda.is_available():
#     print(f"GPU name: {torch.cuda.get_device_name(0)}")

# # Test Wandb initialization
# print("\nTesting Wandb Setup...")
# try:
#     wandb.init(project="ICLR_2025_Blog", name="test_run")
#     print("Wandb initialized successfully")
    
#     # Test logging
#     wandb.log({
#         "test_metric": 0.5,
#         "test_loss": 0.1
#     })
#     print("Wandb logging successful")
#     wandb.finish()
# except Exception as e:
#     print("Error in wandb setup:", str(e))

# # Test training loop components
# print("\nTesting Training Loop Components...")
# try:
#     # Test progress meter
#     batch_time = AverageMeter('Time', ':6.3f')
#     data_time = AverageMeter('Data', ':6.3f')
#     losses = AverageMeter('Loss', ':.4e')
#     top1 = AverageMeter('Acc@1', ':6.2f')
#     progress = ProgressMeter(
#         total_steps,
#         [batch_time, data_time, losses, top1],
#         prefix="Test: "
#     )
    
#     # Update meters
#     batch_time.update(0.5)
#     losses.update(0.1)
#     top1.update(95.0)
    
#     # Test display
#     progress.display(0)
    
#     # Test checkpoint saving
#     print("\nTesting Checkpoint Saving...")
#     dummy_state = {
#         'step': 0,
#         'state_dict': test_model.state_dict(),
#         'best_acc1': 0,
#         'optimizer': test_optimizer.state_dict(),
#         'scheduler': test_scheduler.state_dict()
#     }
#     save_checkpoint(dummy_state, False, checkpoint_path)
#     print(f"Checkpoint saved to {checkpoint_path}")
    
# except Exception as e:
#     print("Error in training loop components:", str(e))

# print("\nAll tests completed!")