In [1]:
!pip install timm accelerate tqdm

Collecting timm
  Downloading timm-1.0.7-py3-none-any.whl.metadata (47 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.5/47.5 kB[0m [31m841.5 kB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
[?25hCollecting accelerate
  Downloading accelerate-0.32.1-py3-none-any.whl.metadata (18 kB)
Collecting huggingface_hub (from timm)
  Downloading huggingface_hub-0.23.4-py3-none-any.whl.metadata (12 kB)
Collecting safetensors (from timm)
  Downloading safetensors-0.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Downloading timm-1.0.7-py3-none-any.whl (2.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m214.2 kB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading accelerate-0.32.1-py3-none-any.whl (314 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m314.1/314.1 kB[0m [31m774.3 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading safetensors-0.4.3-cp310-cp310-manyl

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from timm import create_model
from accelerate import Accelerator
from tqdm import tqdm

class SmoothedLoss(nn.Module):
    def __init__(self):
        super(SmoothedLoss, self).__init__()
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')

    def forward(self, outputs, targets):
        # Calculate individual cross-entropy losses
        ce_loss = self.ce_loss(outputs, targets)
        
        # Get probabilities of the correct classes
        probs = torch.softmax(outputs, dim=1)
        highest_probs, predicted = torch.max(probs, dim=1)
        
        # Check if predictions are correct
        correct_preds = (predicted == targets)
        
        # Calculate the quadratic multiplier
        correct_multiplier = -(highest_probs / 2) ** 2 + 1
        incorrect_multiplier = (highest_probs / 2) ** 2 + 1
        
        # Apply the multipliers
        adjustment = torch.ones_like(ce_loss)
        adjustment[correct_preds] = correct_multiplier[correct_preds]
        adjustment[~correct_preds] = incorrect_multiplier[~correct_preds]
        adjusted_loss = ce_loss * adjustment
        
        # Return the mean loss over the batch
        return adjusted_loss.mean()

class TinyNet(nn.Module):
    def __init__(self, num_classes):
        super(TinyNet, self).__init__()
        self.model = create_model('tinynet_e', pretrained=True, num_classes=num_classes)
        
    def forward(self, x):
        return self.model(x)

def train_and_evaluate(model, criterion, optimizer, train_loader, val_loader, accelerator, num_epochs=50):
    best_val_accuracy = 0
    patience = 5  # Number of epochs to wait if no improvement is seen
    early_stopping_counter = 0
    best_epoch = 0

    for epoch in range(num_epochs):
        model.train()
        run_loss = 0
        correct = 0
        total = 0
        confidence_sum = 0
        
        for inputs, targets in (train_loader):
            inputs, targets = inputs.to(accelerator.device), targets.to(accelerator.device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            run_loss += loss.item() * inputs.size(0)
            
            probs = torch.softmax(outputs, dim=1)
            correct_probs = probs[range(len(targets)), targets]
            confidence_sum += correct_probs.sum().item()
            
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
            
            optimizer.zero_grad()
            accelerator.backward(loss)
            optimizer.step()
        
        train_loss = run_loss / total
        train_accuracy = correct / total
        average_train_confidence = confidence_sum / total

        model.eval()
        val_loss = 0
        correct = 0
        total = 0
        confidence_sum = 0
        best_v_l = 0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(accelerator.device), targets.to(accelerator.device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item() * inputs.size(0)
                
                probs = torch.softmax(outputs, dim=1)
                correct_probs = probs[range(len(targets)), targets]
                confidence_sum += correct_probs.sum().item()
                
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        
        val_loss /= total
        val_accuracy = correct / total
        average_val_confidence = confidence_sum / total
        
        print(f"Epoch {epoch + 1}/{num_epochs}, "
              f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Train Avg Confidence: {average_train_confidence:.4f}, "
              f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}, Validation Avg Confidence: {average_val_confidence:.4f}")
        
        # Early stopping criteria
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            
            best_epoch = epoch + 1
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= patience:
                print(f'Early stopping triggered after {patience} epochs of no improvement.')
                break

    return val_loss, best_val_accuracy, average_val_confidence

def main():
    accelerator = Accelerator(mixed_precision='fp16')
    device = accelerator.device

    # Hyperparameters
    num_classes = 100
    batch_size = 2048
    num_epochs = 100 

    # Data preparation
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness = 0.1, contrast = 0.1, saturation = 0.1, hue = 0.1)
    ])
    train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
    val_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    train_loader, val_loader = accelerator.prepare(train_loader, val_loader)


    print(f"Smoothed Loss")
    
    torch.manual_seed(0)
    model = TinyNet(num_classes)
    criterion = SmoothedLoss()
    optimizer = optim.Adam(model.parameters())
    
    model, optimizer, criterion = accelerator.prepare(model, optimizer, criterion)
    
    smoothed_loss, smoothed_accuracy, smoothed_confidence = train_and_evaluate(model, criterion, optimizer, train_loader, val_loader, accelerator, num_epochs)

    print(f"Cross Entropy Loss")
    
    torch.manual_seed(0)
    model = TinyNet(num_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())
    
    model, optimizer, criterion = accelerator.prepare(model, optimizer, criterion)
    
    ce_loss, ce_accuracy, ce_confidence = train_and_evaluate(model, criterion, optimizer, train_loader, val_loader, accelerator, num_epochs)


    print(" Validation Accuracy with Smoothed Loss:", smoothed_accuracy)
    print(" Validation Confidence with Smoothed Loss:", smoothed_confidence)
    print(" Validation Accuracy with Cross Entropy Loss:", ce_accuracy)
    print(" Validation Confidence with Cross Entropy Loss:", ce_confidence)

if __name__ == "__main__":
    main()


Files already downloaded and verified
Files already downloaded and verified
Smoothed Loss
Epoch 1/100, Train Loss: 4.6124, Train Accuracy: 0.0616, Train Avg Confidence: 0.0261, Validation Loss: 3.9366, Validation Accuracy: 0.1261, Validation Avg Confidence: 0.0636
Epoch 2/100, Train Loss: 3.5959, Train Accuracy: 0.1674, Train Avg Confidence: 0.0793, Validation Loss: 3.4168, Validation Accuracy: 0.1976, Validation Avg Confidence: 0.1094
Epoch 3/100, Train Loss: 3.1893, Train Accuracy: 0.2330, Train Avg Confidence: 0.1193, Validation Loss: 3.1539, Validation Accuracy: 0.2434, Validation Avg Confidence: 0.1405
Epoch 4/100, Train Loss: 2.9423, Train Accuracy: 0.2762, Train Avg Confidence: 0.1516, Validation Loss: 3.0006, Validation Accuracy: 0.2735, Validation Avg Confidence: 0.1560
Epoch 5/100, Train Loss: 2.7844, Train Accuracy: 0.3049, Train Avg Confidence: 0.1743, Validation Loss: 2.9020, Validation Accuracy: 0.2896, Validation Avg Confidence: 0.1706
Epoch 6/100, Train Loss: 2.6549, Tr