In [None]:
class BasicBlockWithDropout(nn.Module):
    """ResNet basic block (two 3x3 convolutions) with dropout between them.

    Args:
        in_planes: number of input channels.
        planes: number of output channels of both 3x3 convolutions.
        stride: stride of the first convolution and of the projection shortcut.
        dropout_rate: probability for the ``nn.Dropout`` applied after the first
            conv -> BN -> ReLU stage.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, dropout_rate=0.5):
        super().__init__()
        out_planes = planes * self.expansion

        # Main path (registration order is kept stable so state_dict keys match).
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut when the shapes already agree; otherwise project the
        # input with a strided 1x1 conv + BN so it can be added to the main path.
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = self.dropout(self.relu(self.bn1(self.conv1(x))))
        hidden = self.bn2(self.conv2(hidden))
        return self.relu(hidden + self.shortcut(x))

class ResNet18WithDropoutBlocks(nn.Module):
    """ResNet-18 layout (four stages of two basic blocks) built from BasicBlockWithDropout.

    An extra dropout layer is placed in front of the final linear classifier.

    Args:
        dropout_rate: dropout probability used inside every block and before ``fc``.
        num_classes: number of output logits.
    """

    def __init__(self, dropout_rate=0.5, num_classes=1000):
        super().__init__()
        block = BasicBlockWithDropout
        # Channel count fed to the next block; advanced by _make_layer.
        self.in_planes = 64

        # Stem: 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four residual stages; every stage after the first halves the resolution.
        make_stage = self._make_layer
        self.layer1 = make_stage(block, 64, 2, stride=1, dropout_rate=dropout_rate)
        self.layer2 = make_stage(block, 128, 2, stride=2, dropout_rate=dropout_rate)
        self.layer3 = make_stage(block, 256, 2, stride=2, dropout_rate=dropout_rate)
        self.layer4 = make_stage(block, 512, 2, stride=2, dropout_rate=dropout_rate)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        classifier_in = 512 * block.expansion
        self.fc = nn.Sequential(
            nn.Dropout(p=dropout_rate),  # dropout right before the final linear layer
            nn.Linear(classifier_in, num_classes),
        )

    def _make_layer(self, block, planes, blocks, stride, dropout_rate):
        """Stack ``blocks`` copies of ``block``; only the first one uses ``stride``.

        Side effect: advances ``self.in_planes`` to ``planes * block.expansion``.
        """
        stage = []
        for block_stride in (stride,) + (1,) * (blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride, dropout_rate))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        pooled = torch.flatten(self.avgpool(x), 1)
        return self.fc(pooled)

# Example usage:
# NOTE(review): this cell relies on `num_epochs`, `train_loader` and `val_loader`
# being defined elsewhere (they are not defined in this cell), and on `torch` / `nn`
# having been imported by an earlier-executed cell.
model = ResNet18WithDropoutBlocks(dropout_rate=0.5, num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Training loop (simplified): one optimizer step per training batch, followed by
# a validation pass each epoch.
for epoch in range(num_epochs):
    model.train()  # dropout active, BatchNorm uses batch statistics
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    
    # Validation loop (simplified)
    model.eval()  # dropout disabled, BatchNorm uses running statistics
    val_loss = 0
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)
            loss = criterion(output, target)
            val_loss += loss.item()
    
    # Mean of the per-batch losses (not weighted by batch size).
    val_loss /= len(val_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {val_loss}")


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torch.optim.lr_scheduler import StepLR, MultiStepLR, ExponentialLR, ReduceLROnPlateau

# Define your model
class CustomResNet50(nn.Module):
    """ImageNet-pretrained ResNet-50 feature extractor with a small custom head.

    The backbone's ``fc`` is replaced by ``nn.Identity`` so it returns pooled
    features. The head is Dropout -> Linear -> BatchNorm1d over the logits.

    Args:
        num_features: width of the backbone output fed to the linear layer.
        num_classes: number of output logits.
        dropout_rate: dropout probability in front of the linear layer.
    """

    def __init__(self, num_features=2048, num_classes=10, dropout_rate=0.5):
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        backbone.fc = nn.Identity()  # strip the ImageNet classifier
        self.model = backbone

        head = [
            nn.Dropout(p=dropout_rate, inplace=False),
            nn.Linear(num_features, num_classes),
            nn.BatchNorm1d(num_classes),
        ]
        self.classifier_layer = nn.Sequential(*head)

    def forward(self, x):
        return self.classifier_layer(self.model(x))

# Example usage
criterion = nn.CrossEntropyLoss()

# Different schedulers to test.
# Each entry is a factory that receives the optimizer it should drive. A scheduler
# only ever updates the optimizer it was constructed with, so the schedulers must
# be built *after* each run's optimizer. Building them once up front against a
# single optimizer left every run's real optimizer at a constant lr=0.01.
scheduler_factories = {
    'StepLR': lambda opt: StepLR(opt, step_size=5, gamma=0.1),
    'MultiStepLR': lambda opt: MultiStepLR(opt, milestones=[3, 6, 9], gamma=0.1),
    'ExponentialLR': lambda opt: ExponentialLR(opt, gamma=0.9),
    'ReduceLROnPlateau': lambda opt: ReduceLROnPlateau(opt, mode='min', factor=0.1, patience=2),
}

# Dummy data loaders (replace with actual data loaders)
train_loader = [(torch.randn(16, 3, 224, 224), torch.randint(0, 10, (16,))) for _ in range(50)]
val_loader = [(torch.randn(16, 3, 224, 224), torch.randint(0, 10, (16,))) for _ in range(10)]


# Training function
def train_model(model, optimizer, scheduler, num_epochs=10):
    """Train ``model`` for ``num_epochs`` epochs, validating and stepping ``scheduler`` each epoch.

    Uses the module-level ``criterion``, ``train_loader`` and ``val_loader``.
    ``ReduceLROnPlateau`` is stepped with the mean validation loss. Every other
    scheduler is stepped without arguments. The learning rate of the first
    param group is printed so the schedules can be compared.
    """
    for epoch in range(num_epochs):
        model.train()
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for data, target in val_loader:
                output = model(data)
                loss = criterion(output, target)
                val_loss += loss.item()

        val_loss /= len(val_loader)  # mean of per-batch losses

        # Update scheduler
        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(val_loss)
        else:
            scheduler.step()

        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {val_loss}, LR: {current_lr:.6g}")


# Test each scheduler on a freshly initialised model / optimizer / scheduler triple.
for name, make_scheduler in scheduler_factories.items():
    print(f"\nTesting scheduler: {name}")
    model = CustomResNet50(num_classes=10)  # Reinitialize model for each test
    optimizer = optim.SGD(model.parameters(), lr=0.01)  # Reinitialize optimizer
    scheduler = make_scheduler(optimizer)  # Bind the scheduler to *this* optimizer
    train_model(model, optimizer, scheduler, num_epochs=10)


In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets, models
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.tensorboard import SummaryWriter
import time
from sklearn.metrics import f1_score, accuracy_score
from collections import defaultdict
import pandas as pd

# Dataset-derived settings.
# NOTE(review): `dataset` is not defined in this cell; it (and later
# `train_dataloader`, `valid_dataloader`, `target_dir`, `batch_size`) must come
# from an earlier cell — confirm.
class_names = dataset.classes
num_classes = len(class_names)
num_epochs = 200

# metrics[model_name][metric_name] -> list of per-epoch values
metrics = defaultdict(lambda: defaultdict(list))

# Define the models you want to compare (torchvision.models constructor names)
model_names = ['resnet18', 'resnet50', 'resnet101']

# Define the loss function
loss_fn = nn.CrossEntropyLoss()

# Define the learning rate
lr = 0.001

# Device configuration: first CUDA device if available, else CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Initialize the TensorBoard writer (default log directory)
tb = SummaryWriter()

def train_model(model, dataloader, loss_fn, optimizer, scheduler, device):
    """Train ``model`` for one pass over ``dataloader`` and step ``scheduler`` once.

    Args:
        model: network to optimise; switched to train mode.
        dataloader: yields ``(inputs, labels)`` batches; ``len(dataloader.dataset)``
            is used to average the metrics.
        loss_fn: criterion called as ``loss_fn(outputs, labels)``.
        optimizer: stepped once per batch.
        scheduler: LR scheduler stepped once at the end of the epoch.
        device: device every batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc, epoch_f1)``: the sample-weighted mean loss, the
        accuracy as a 0-dim double tensor, and sklearn's weighted F1 score.
    """
    model.train()
    loss_total = 0.0
    correct_total = 0
    seen_preds = []
    seen_labels = []

    for batch_inputs, batch_labels in dataloader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        with torch.set_grad_enabled(True):
            logits = model(batch_inputs)
            batch_preds = torch.max(logits, 1).indices
            batch_loss = loss_fn(logits, batch_labels)
            batch_loss.backward()
            optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch loss is a per-sample mean.
        loss_total += batch_loss.item() * batch_inputs.size(0)
        correct_total += torch.sum(batch_preds == batch_labels.data)
        seen_preds.extend(batch_preds.cpu().numpy())
        seen_labels.extend(batch_labels.cpu().numpy())

    scheduler.step()

    dataset_size = len(dataloader.dataset)
    return (
        loss_total / dataset_size,
        correct_total.double() / dataset_size,
        f1_score(seen_labels, seen_preds, average='weighted'),
    )

def valid_model(model, dataloader, loss_fn, device):
    """Evaluate ``model`` over ``dataloader`` without gradients and return epoch metrics.

    Args:
        model: network to evaluate; switched to eval mode.
        dataloader: yields ``(inputs, labels)`` batches; ``len(dataloader.dataset)``
            is used to average the metrics.
        loss_fn: criterion called as ``loss_fn(outputs, labels)``.
        device: device every batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc, epoch_f1)``: the sample-weighted mean loss, the
        accuracy as a 0-dim double tensor, and sklearn's weighted F1 score.
    """
    model.eval()
    loss_total = 0.0
    correct_total = 0
    seen_preds, seen_labels = [], []

    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)
            batch_preds = torch.max(logits, 1).indices
            batch_loss = loss_fn(logits, batch_labels)

            loss_total += batch_loss.item() * batch_inputs.size(0)
            correct_total += torch.sum(batch_preds == batch_labels.data)
            seen_preds.extend(batch_preds.cpu().numpy())
            seen_labels.extend(batch_labels.cpu().numpy())

    dataset_size = len(dataloader.dataset)
    return (
        loss_total / dataset_size,
        correct_total.double() / dataset_size,
        f1_score(seen_labels, seen_preds, average='weighted'),
    )

def save_model(model, target_dir, model_name):
    """Save ``model.state_dict()`` to ``<target_dir>/<model_name>``.

    ``target_dir`` (including any missing parents) is created first. Without
    this, the first checkpoint of a run fails with FileNotFoundError only after
    a full epoch of training.

    Args:
        model: module whose ``state_dict()`` is saved.
        target_dir: directory that receives the checkpoint file.
        model_name: checkpoint file name, e.g. ``'resnet18_best.pth'``.
    """
    import os

    os.makedirs(target_dir, exist_ok=True)
    torch.save(model.state_dict(), f'{target_dir}/{model_name}')

# Start the training process for each model
for model_name in model_names:
    print(f"[INFO]: Training {model_name} model")

    # Load the ImageNet-pretrained backbone and replace its classifier head with
    # one sized for this dataset's `num_classes`.
    model = getattr(models, model_name)(pretrained=True)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes)
    model = model.to(device)

    # Define the optimizer and scheduler.
    # NOTE(review): train_model steps this scheduler once per epoch, so the LR is
    # multiplied by 0.1 every 7 epochs (~28 times over 200 epochs) — confirm intended.
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    # Train and evaluate the model; keep the checkpoint with the best validation accuracy.
    best_acc = 0.0
    for epoch in range(num_epochs):
        print(f"[INFO]: Epoch {epoch+1} of {num_epochs} for {model_name}")

        avg_train_loss, avg_train_acc, train_weighted_f1 = train_model(model, train_dataloader, loss_fn, optimizer, exp_lr_scheduler, device)
        metrics[model_name]['train_loss'].append(avg_train_loss)
        metrics[model_name]['train_acc'].append(avg_train_acc.item())
        metrics[model_name]['train_f1'].append(train_weighted_f1)

        tb.add_scalar(f'{model_name} Training Loss', avg_train_loss, epoch)
        tb.add_scalar(f'{model_name} Training Accuracy', avg_train_acc, epoch)
        tb.add_scalar(f'{model_name} Training F1 Score', train_weighted_f1, epoch)

        avg_valid_loss, avg_valid_acc, valid_weighted_f1 = valid_model(model, valid_dataloader, loss_fn, device)
        metrics[model_name]['valid_loss'].append(avg_valid_loss)
        metrics[model_name]['valid_acc'].append(avg_valid_acc.item())
        metrics[model_name]['valid_f1'].append(valid_weighted_f1)

        tb.add_scalar(f'{model_name} Validation Loss', avg_valid_loss, epoch)
        tb.add_scalar(f'{model_name} Validation Accuracy', avg_valid_acc, epoch)
        tb.add_scalar(f'{model_name} Validation F1 Score', valid_weighted_f1, epoch)

        # Accuracies are fractions in [0, 1] (correct / dataset size); scale by 100
        # before printing them with a '%' sign.
        print(f'[INFO] {model_name} - Training Loss: {avg_train_loss:.4f} | Acc: {100 * avg_train_acc.item():.2f}% | Weighted F1: {train_weighted_f1:.4f}')
        print(f'[INFO] {model_name} - Validation Loss: {avg_valid_loss:.4f} | Acc: {100 * avg_valid_acc.item():.2f}% | Weighted F1: {valid_weighted_f1:.4f}')

        if avg_valid_acc > best_acc:
            best_acc = avg_valid_acc
            print(f'[INFO] Improvement detected for {model_name}, saving best model.')
            save_model(model, target_dir, f'{model_name}_best.pth')

        print('-'*80)

    save_model(model, target_dir, f'{model_name}_last.pth')

print("Training complete for all models.")

tb.close()

# Collect the per-model metric histories into a DataFrame (one column per model)
# and display it.
metrics_df = pd.DataFrame(metrics)
try:
    # `ace_tools` only exists in some hosted notebook sandboxes; fall back to a
    # plain print elsewhere instead of crashing after training has finished.
    import ace_tools as tools
except ImportError:
    print(metrics_df)
else:
    tools.display_dataframe_to_user("Model Metrics", metrics_df)

# Plotting the results
from utilities import visual

# Run settings shared by every model (built once, outside the loop).
# NOTE(review): this dict is never passed to `visual.plot_results` — confirm
# whether the plotting helper should receive it. `batch_size` is not defined in
# this cell and must come from an earlier one.
settings = {
    'num_epochs': num_epochs,
    'batch_size': batch_size,
    'criterion': 'CrossEntropyLoss',
    'optimizer': 'SGD',  # the models above are trained with optim.SGD (momentum=0.9), not Adam
    'lr': lr,
}

for model_name in model_names:
    # Map the recorded metric lists onto the key names the plotting helper reads.
    history = {
        'train_loss': metrics[model_name]['train_loss'],
        'valid_loss': metrics[model_name]['valid_loss'],
        'train_acc': metrics[model_name]['train_acc'],
        'valid_acc': metrics[model_name]['valid_acc'],
        'train_weighted_f1': metrics[model_name]['train_f1'],
        'valid_weighted_f1': metrics[model_name]['valid_f1'],
    }

    title = f'Performance of {model_name}'
    visual.plot_results(history, target_dir, title)


# Test: train ResNet-50 from scratch on CIFAR-10 with LAMB and a warmup + cosine LR schedule

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
import math

# Install the apex library if not already installed for LAMB optimizer
# !pip install apex

from apex.optimizers import FusedLAMB

# Define the learning rate schedule
class CustomLRScheduler:
    """Epoch-level LR schedule: linear warmup, cosine decay, then a constant cooldown.

    ``step(epoch)`` writes the same learning rate into every param group of
    ``optimizer``:

    * ``epoch < warmup_epochs``: linear ramp ``base_lr * (epoch + 1) / warmup_epochs``,
      which reaches ``base_lr`` on the last warmup epoch. The previous ramp,
      ``epoch / warmup_epochs``, ran epoch 0 with LR 0 and wasted that epoch.
    * ``warmup_epochs <= epoch < total_epochs - cooldown_epochs``: cosine decay
      from ``base_lr`` towards ``final_lr``.
    * afterwards: constant ``final_lr``.

    Args:
        optimizer: object with a ``param_groups`` list of dicts (e.g. a torch optimizer).
        warmup_epochs: number of warmup epochs (0 disables warmup).
        total_epochs: total number of training epochs, *including* the cooldown.
        base_lr: peak learning rate.
        final_lr: learning rate at the end of the decay and during the cooldown.
        cooldown_epochs: trailing epochs held at ``final_lr``. Defaults to 10,
            the value that used to be hard-coded.
    """

    def __init__(self, optimizer, warmup_epochs, total_epochs, base_lr, final_lr, cooldown_epochs=10):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.base_lr = base_lr
        self.final_lr = final_lr
        self.cooldown_epochs = cooldown_epochs
        self.current_epoch = 0

    def step(self, epoch):
        """Set the learning rate for ``epoch`` (0-based) on every param group."""
        self.current_epoch = epoch
        decay_end = self.total_epochs - self.cooldown_epochs
        if epoch < self.warmup_epochs:
            lr = self.base_lr * (epoch + 1) / self.warmup_epochs
        elif epoch < decay_end:
            # Denominator is > 0 here because warmup_epochs <= epoch < decay_end.
            progress = (epoch - self.warmup_epochs) / (decay_end - self.warmup_epochs)
            lr = self.final_lr + (self.base_lr - self.final_lr) * 0.5 * (1 + math.cos(math.pi * progress))
        else:
            lr = self.final_lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def get_lr(self):
        """Return the learning rate currently set on the first param group."""
        return self.optimizer.param_groups[0]['lr']

# Set parameters
num_epochs = 610
warmup_epochs = 5
# Length of the constant-LR tail. CustomLRScheduler holds final_lr for the last
# 10 epochs of its `total_epochs`, so it must be given the full `num_epochs`.
# Passing `num_epochs - cooldown_epochs` subtracted the cooldown twice and held
# final_lr for 20 epochs instead of 10.
cooldown_epochs = 10
# NOTE(review): 0.1 is far above the 5e-3 used with FusedLAMB in the later cells — confirm.
base_lr = 0.1
final_lr = 1e-5

# Prepare data loaders (example using CIFAR10)
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4)

# Initialize model, optimizer, and scheduler.
# Size the classifier head to the dataset (10 CIFAR-10 classes) instead of
# torchvision's default 1000 outputs.
model = models.resnet50(pretrained=False, num_classes=len(train_dataset.classes))
model = model.cuda()
optimizer = FusedLAMB(model.parameters(), lr=base_lr)
scheduler = CustomLRScheduler(optimizer, warmup_epochs, num_epochs, base_lr, final_lr)

criterion = nn.CrossEntropyLoss()

# Training loop: the LR is set once at the start of every epoch.
for epoch in range(num_epochs):
    model.train()
    scheduler.step(epoch)
    running_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.cuda(), targets.cuda()

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, LR: {scheduler.get_lr():.6f}')

print('Training complete')


# Data pipeline: ImageNet with RandAugment and Mixup/CutMix (used by the training cells below)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
import timm
from timm.data import create_transform
from timm.data.mixup import Mixup
from timm.data.auto_augment import rand_augment_transform
from timm.data.transforms import str_to_interp_mode

# For LAMB optimizer
from apex.optimizers import FusedLAMB

# Data Pipeline
from timm.data.transforms import RandomResizedCropAndInterpolation

batch_size = 192 * 8
# NOTE(review): augmentation_factor is never used below (presumably meant for
# repeated augmentation) — confirm or remove.
augmentation_factor = 3

transform = transforms.Compose([
    # timm's crop transform accepts interpolation='random' (a random choice between
    # bilinear and bicubic on each call). `str_to_interp_mode('random')` has no
    # 'random' entry and raises KeyError.
    RandomResizedCropAndInterpolation(224, scale=(0.2, 1.0), interpolation='random'),
    transforms.RandomHorizontalFlip(),
    rand_augment_transform('rand-m7-n3-mstd0.5', {}),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = datasets.ImageNet(root='./data', split='train', transform=transform)
# drop_last=True: timm's Mixup requires an even batch size, and the final
# partial batch of an epoch can be odd.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)

# Mixup and Cutmix (produces soft, label-smoothed targets over 1000 classes)
mixup_fn = Mixup(mixup_alpha=0.2, cutmix_alpha=1.0, cutmix_minmax=None, prob=1.0, switch_prob=0.5, mode='batch', label_smoothing=0.1, num_classes=1000)


In [None]:
import math

# Custom learning rate scheduler
class CustomLRScheduler:
    """Epoch-level LR schedule: linear warmup, cosine decay, then a constant cooldown.

    ``step(epoch)`` writes the same learning rate into every param group of
    ``optimizer``:

    * ``epoch < warmup_epochs``: linear ramp ``base_lr * (epoch + 1) / warmup_epochs``,
      which reaches ``base_lr`` on the last warmup epoch. The previous ramp,
      ``epoch / warmup_epochs``, ran epoch 0 with LR 0 and wasted that epoch.
    * ``warmup_epochs <= epoch < total_epochs - cooldown_epochs``: cosine decay
      from ``base_lr`` towards ``final_lr``.
    * afterwards: constant ``final_lr``.

    Args:
        optimizer: object with a ``param_groups`` list of dicts (e.g. a torch optimizer).
        warmup_epochs: number of warmup epochs (0 disables warmup).
        total_epochs: total number of training epochs, *including* the cooldown.
        base_lr: peak learning rate.
        final_lr: learning rate at the end of the decay and during the cooldown.
        cooldown_epochs: trailing epochs held at ``final_lr``. Defaults to 10,
            the value that used to be hard-coded.
    """

    def __init__(self, optimizer, warmup_epochs, total_epochs, base_lr, final_lr, cooldown_epochs=10):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.base_lr = base_lr
        self.final_lr = final_lr
        self.cooldown_epochs = cooldown_epochs
        self.current_epoch = 0

    def step(self, epoch):
        """Set the learning rate for ``epoch`` (0-based) on every param group."""
        self.current_epoch = epoch
        decay_end = self.total_epochs - self.cooldown_epochs
        if epoch < self.warmup_epochs:
            lr = self.base_lr * (epoch + 1) / self.warmup_epochs
        elif epoch < decay_end:
            # Denominator is > 0 here because warmup_epochs <= epoch < decay_end.
            progress = (epoch - self.warmup_epochs) / (decay_end - self.warmup_epochs)
            lr = self.final_lr + (self.base_lr - self.final_lr) * 0.5 * (1 + math.cos(math.pi * progress))
        else:
            lr = self.final_lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def get_lr(self):
        """Return the learning rate currently set on the first param group."""
        return self.optimizer.param_groups[0]['lr']

# Initialize model, optimizer, and scheduler
num_epochs = 610
model = models.resnet50(pretrained=False)
model = model.cuda()

optimizer = FusedLAMB(model.parameters(), lr=5e-3)
scheduler = CustomLRScheduler(optimizer, warmup_epochs=5, total_epochs=num_epochs, base_lr=5e-3, final_lr=1e-5)

# Loss functions.
# nn.KLDivLoss expects its *input* as log-probabilities and its target as
# probabilities. Feeding raw logits (as before) computes a meaningless value.
# reduction='batchmean' gives the true per-sample KL divergence; the default
# 'mean' also divides by the number of classes.
# NOTE(review): despite the "kd" name there is no teacher model here; the KL term
# is taken against the mixup targets — confirm this is the intended recipe.
criterion_kd = nn.KLDivLoss(reduction='batchmean')
criterion_ce = nn.CrossEntropyLoss()

# Training loop
for epoch in range(num_epochs):
    model.train()
    scheduler.step(epoch)
    running_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.cuda(), targets.cuda()

        # After mixup, targets are presumably soft label distributions (batch, 1000).
        inputs, targets = mixup_fn(inputs, targets)

        optimizer.zero_grad()
        outputs = model(inputs)

        # Combining losses
        log_probs = outputs.log_softmax(dim=1)
        loss = 0.8 * criterion_kd(log_probs, targets) + 0.2 * criterion_ce(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, LR: {scheduler.get_lr():.6f}')

print('Training complete')


# Modified Training Script with EMA

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
import timm
from timm.data import create_transform
from timm.data.mixup import Mixup
from timm.data.auto_augment import rand_augment_transform
from timm.data.transforms import str_to_interp_mode
from timm.utils import ModelEmaV2

# For LAMB optimizer
from apex.optimizers import FusedLAMB

# Data Pipeline
from timm.data.transforms import RandomResizedCropAndInterpolation

batch_size = 192 * 8
# NOTE(review): augmentation_factor is never used below (presumably meant for
# repeated augmentation) — confirm or remove.
augmentation_factor = 3

transform = transforms.Compose([
    # timm's crop transform accepts interpolation='random' (a random choice between
    # bilinear and bicubic on each call). `str_to_interp_mode('random')` has no
    # 'random' entry and raises KeyError.
    RandomResizedCropAndInterpolation(224, scale=(0.2, 1.0), interpolation='random'),
    transforms.RandomHorizontalFlip(),
    rand_augment_transform('rand-m7-n3-mstd0.5', {}),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = datasets.ImageNet(root='./data', split='train', transform=transform)
# drop_last=True: timm's Mixup requires an even batch size, and the final
# partial batch of an epoch can be odd.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)

# Mixup and Cutmix (produces soft, label-smoothed targets over 1000 classes)
mixup_fn = Mixup(mixup_alpha=0.2, cutmix_alpha=1.0, cutmix_minmax=None, prob=1.0, switch_prob=0.5, mode='batch', label_smoothing=0.1, num_classes=1000)

import math

# Custom learning rate scheduler
class CustomLRScheduler:
    """Epoch-level LR schedule: linear warmup, cosine decay, then a constant cooldown.

    ``step(epoch)`` writes the same learning rate into every param group of
    ``optimizer``:

    * ``epoch < warmup_epochs``: linear ramp ``base_lr * (epoch + 1) / warmup_epochs``,
      which reaches ``base_lr`` on the last warmup epoch. The previous ramp,
      ``epoch / warmup_epochs``, ran epoch 0 with LR 0 and wasted that epoch.
    * ``warmup_epochs <= epoch < total_epochs - cooldown_epochs``: cosine decay
      from ``base_lr`` towards ``final_lr``.
    * afterwards: constant ``final_lr``.

    Args:
        optimizer: object with a ``param_groups`` list of dicts (e.g. a torch optimizer).
        warmup_epochs: number of warmup epochs (0 disables warmup).
        total_epochs: total number of training epochs, *including* the cooldown.
        base_lr: peak learning rate.
        final_lr: learning rate at the end of the decay and during the cooldown.
        cooldown_epochs: trailing epochs held at ``final_lr``. Defaults to 10,
            the value that used to be hard-coded.
    """

    def __init__(self, optimizer, warmup_epochs, total_epochs, base_lr, final_lr, cooldown_epochs=10):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.base_lr = base_lr
        self.final_lr = final_lr
        self.cooldown_epochs = cooldown_epochs
        self.current_epoch = 0

    def step(self, epoch):
        """Set the learning rate for ``epoch`` (0-based) on every param group."""
        self.current_epoch = epoch
        decay_end = self.total_epochs - self.cooldown_epochs
        if epoch < self.warmup_epochs:
            lr = self.base_lr * (epoch + 1) / self.warmup_epochs
        elif epoch < decay_end:
            # Denominator is > 0 here because warmup_epochs <= epoch < decay_end.
            progress = (epoch - self.warmup_epochs) / (decay_end - self.warmup_epochs)
            lr = self.final_lr + (self.base_lr - self.final_lr) * 0.5 * (1 + math.cos(math.pi * progress))
        else:
            lr = self.final_lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def get_lr(self):
        """Return the learning rate currently set on the first param group."""
        return self.optimizer.param_groups[0]['lr']

# Initialize model, optimizer, and scheduler
num_epochs = 610
model = models.resnet50(pretrained=False)
model = model.cuda()

# EMA Model Wrapper: exponential moving average (decay 0.9999) of the weights,
# updated after every optimizer step below.
model_ema = ModelEmaV2(model, decay=0.9999, device='cuda')

optimizer = FusedLAMB(model.parameters(), lr=5e-3)
scheduler = CustomLRScheduler(optimizer, warmup_epochs=5, total_epochs=num_epochs, base_lr=5e-3, final_lr=1e-5)

# Loss functions.
# nn.KLDivLoss expects its *input* as log-probabilities and its target as
# probabilities. Feeding raw logits (as before) computes a meaningless value.
# reduction='batchmean' gives the true per-sample KL divergence; the default
# 'mean' also divides by the number of classes.
# NOTE(review): despite the "kd" name there is no teacher model here; the KL term
# is taken against the mixup targets — confirm this is the intended recipe.
criterion_kd = nn.KLDivLoss(reduction='batchmean')
criterion_ce = nn.CrossEntropyLoss()

# Training loop
for epoch in range(num_epochs):
    model.train()
    scheduler.step(epoch)
    running_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.cuda(), targets.cuda()

        # After mixup, targets are presumably soft label distributions (batch, 1000).
        inputs, targets = mixup_fn(inputs, targets)

        optimizer.zero_grad()
        outputs = model(inputs)

        # Combining losses
        log_probs = outputs.log_softmax(dim=1)
        loss = 0.8 * criterion_kd(log_probs, targets) + 0.2 * criterion_ce(outputs, targets)
        loss.backward()
        optimizer.step()

        # Update EMA model
        model_ema.update(model)

        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, LR: {scheduler.get_lr():.6f}')

print('Training complete')

# Use EMA model for evaluation
ema_model = model_ema.module


In [None]:
from timm.utils import ModelEmaV2
# Wrap the current `model` in an exponential-moving-average copy (decay 0.9999)
# kept on the CUDA device.
# NOTE(review): `model` must already exist from an earlier cell; re-running this
# cell replaces any previous `model_ema`.
model_ema = ModelEmaV2(model, decay=0.9999, device='cuda')


# Update the EMA model during training:

In [None]:
# Fold the current weights of `model` into the EMA copy. The training loops in
# this notebook call this once after every optimizer.step().
model_ema.update(model)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
import timm
from timm.data import create_transform
from timm.data.mixup import Mixup
from timm.data.auto_augment import rand_augment_transform
from timm.data.transforms import str_to_interp_mode
from timm.utils import ModelEmaV2

# For LAMB optimizer
from apex.optimizers import FusedLAMB

# Data Pipeline
from timm.data.transforms import RandomResizedCropAndInterpolation

batch_size = 192 * 8
# NOTE(review): augmentation_factor is never used below (presumably meant for
# repeated augmentation) — confirm or remove.
augmentation_factor = 3

transform = transforms.Compose([
    # timm's crop transform accepts interpolation='random' (a random choice between
    # bilinear and bicubic on each call). `str_to_interp_mode('random')` has no
    # 'random' entry and raises KeyError.
    RandomResizedCropAndInterpolation(224, scale=(0.2, 1.0), interpolation='random'),
    transforms.RandomHorizontalFlip(),
    rand_augment_transform('rand-m7-n3-mstd0.5', {}),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = datasets.ImageNet(root='./data', split='train', transform=transform)
# drop_last=True: timm's Mixup requires an even batch size, and the final
# partial batch of an epoch can be odd.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)

# Mixup and Cutmix (produces soft, label-smoothed targets over 1000 classes)
mixup_fn = Mixup(mixup_alpha=0.2, cutmix_alpha=1.0, cutmix_minmax=None, prob=1.0, switch_prob=0.5, mode='batch', label_smoothing=0.1, num_classes=1000)

import math

# Custom learning rate scheduler
class CustomLRScheduler:
    """Epoch-level LR schedule: linear warmup, cosine decay, then a constant cooldown.

    ``step(epoch)`` writes the same learning rate into every param group of
    ``optimizer``:

    * ``epoch < warmup_epochs``: linear ramp ``base_lr * (epoch + 1) / warmup_epochs``,
      which reaches ``base_lr`` on the last warmup epoch. The previous ramp,
      ``epoch / warmup_epochs``, ran epoch 0 with LR 0 and wasted that epoch.
    * ``warmup_epochs <= epoch < total_epochs - cooldown_epochs``: cosine decay
      from ``base_lr`` towards ``final_lr``.
    * afterwards: constant ``final_lr``.

    Args:
        optimizer: object with a ``param_groups`` list of dicts (e.g. a torch optimizer).
        warmup_epochs: number of warmup epochs (0 disables warmup).
        total_epochs: total number of training epochs, *including* the cooldown.
        base_lr: peak learning rate.
        final_lr: learning rate at the end of the decay and during the cooldown.
        cooldown_epochs: trailing epochs held at ``final_lr``. Defaults to 10,
            the value that used to be hard-coded.
    """

    def __init__(self, optimizer, warmup_epochs, total_epochs, base_lr, final_lr, cooldown_epochs=10):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.base_lr = base_lr
        self.final_lr = final_lr
        self.cooldown_epochs = cooldown_epochs
        self.current_epoch = 0

    def step(self, epoch):
        """Set the learning rate for ``epoch`` (0-based) on every param group."""
        self.current_epoch = epoch
        decay_end = self.total_epochs - self.cooldown_epochs
        if epoch < self.warmup_epochs:
            lr = self.base_lr * (epoch + 1) / self.warmup_epochs
        elif epoch < decay_end:
            # Denominator is > 0 here because warmup_epochs <= epoch < decay_end.
            progress = (epoch - self.warmup_epochs) / (decay_end - self.warmup_epochs)
            lr = self.final_lr + (self.base_lr - self.final_lr) * 0.5 * (1 + math.cos(math.pi * progress))
        else:
            lr = self.final_lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def get_lr(self):
        """Return the learning rate currently set on the first param group."""
        return self.optimizer.param_groups[0]['lr']

# Initialize model, optimizer, and scheduler
num_epochs = 610
model = models.resnet50(pretrained=False)
model = model.cuda()

# EMA Model Wrapper: exponential moving average (decay 0.9999) of the weights,
# updated after every optimizer step below.
model_ema = ModelEmaV2(model, decay=0.9999, device='cuda')

optimizer = FusedLAMB(model.parameters(), lr=5e-3)
scheduler = CustomLRScheduler(optimizer, warmup_epochs=5, total_epochs=num_epochs, base_lr=5e-3, final_lr=1e-5)

# Loss functions.
# nn.KLDivLoss expects its *input* as log-probabilities and its target as
# probabilities. Feeding raw logits (as before) computes a meaningless value.
# reduction='batchmean' gives the true per-sample KL divergence; the default
# 'mean' also divides by the number of classes.
# NOTE(review): despite the "kd" name there is no teacher model here; the KL term
# is taken against the mixup targets — confirm this is the intended recipe.
criterion_kd = nn.KLDivLoss(reduction='batchmean')
criterion_ce = nn.CrossEntropyLoss()

# Training loop
for epoch in range(num_epochs):
    model.train()
    scheduler.step(epoch)
    running_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.cuda(), targets.cuda()

        # After mixup, targets are presumably soft label distributions (batch, 1000).
        inputs, targets = mixup_fn(inputs, targets)

        optimizer.zero_grad()
        outputs = model(inputs)

        # Combining losses
        log_probs = outputs.log_softmax(dim=1)
        loss = 0.8 * criterion_kd(log_probs, targets) + 0.2 * criterion_ce(outputs, targets)
        loss.backward()
        optimizer.step()

        # Update EMA model
        model_ema.update(model)

        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, LR: {scheduler.get_lr():.6f}')

print('Training complete')

# Use EMA model for evaluation
ema_model = model_ema.module

# Evaluation function
def evaluate(model, dataloader):
    """Return the top-1 accuracy of ``model`` over ``dataloader`` as a float in [0, 1].

    Puts ``model`` in eval mode (and leaves it there) and runs without gradients.
    Each batch is moved to the device of the model's first parameter, falling
    back to CPU for a parameterless model. Previously batches were always sent
    to CUDA, which broke CPU evaluation.

    Args:
        model: classifier whose outputs are scored by argmax over dim 1.
        dataloader: iterable of ``(inputs, targets)`` batches with integer class targets.

    Raises:
        ValueError: if ``dataloader`` yields no samples (previously a ZeroDivisionError).
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    if total == 0:
        raise ValueError('evaluate() received a dataloader with no samples')
    return correct / total

# DataLoader for the validation set: deterministic resize + center crop, with the
# same normalization constants as the training transform above.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

val_dataset = datasets.ImageNet(root='./data', split='val', transform=val_transform)
# shuffle=False: order does not matter for accuracy, so keep iteration deterministic.
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# Evaluate the EMA model
# NOTE(review): only the EMA weights are evaluated; also evaluating `model` would
# show whether EMA actually helps.
ema_model_accuracy = evaluate(ema_model, val_loader)
print(f'EMA Model Accuracy: {ema_model_accuracy:.4f}')
