<a href="https://colab.research.google.com/github/NiveditaS22/restnet/blob/main/Task.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import StepLR
import torch.nn.utils.prune as prune
import numpy as np

# Define the ResNet18 Model with Pretrained Weights
class ResNetCIFAR10(nn.Module):
    """ResNet-18 backbone from torchvision with a freshly initialised classifier head.

    Args:
        num_classes: output size of the replacement final linear layer
            (10 for CIFAR-10).
        pretrained: load torchvision's ImageNet weights for the backbone.

    NOTE(review): the stock ImageNet stem (7x7 stride-2 conv + max-pool) is kept,
    so 32x32 CIFAR images are heavily downsampled early — confirm this is intended.
    """

    def __init__(self, num_classes=10, pretrained=True):
        super().__init__()
        weights_enum = getattr(models, 'ResNet18_Weights', None)
        if weights_enum is not None:
            # torchvision >= 0.13: ``weights=`` replaces the deprecated ``pretrained=`` flag.
            self.resnet = models.resnet18(weights=weights_enum.DEFAULT if pretrained else None)
        else:
            # Older torchvision only understands the legacy boolean flag.
            self.resnet = models.resnet18(pretrained=pretrained)
        # Replace the ImageNet head. The attribute name ``resnet`` is kept so saved
        # state_dict keys ('resnet.*') stay compatible with existing checkpoints.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        """Return the class logits produced by the wrapped ResNet-18 for batch ``x``."""
        return self.resnet(x)

# Function to initialize and return data loaders
def get_data_loaders(batch_size, train_fraction=0.8, root='./data', num_workers=2, seed=0):
    """Build CIFAR-10 train, validation and test DataLoaders.

    The 50k-image CIFAR-10 training split is divided into disjoint train and
    validation subsets. The train subset is read through the augmenting
    transform (random crop + horizontal flip). The validation subset is read
    through the deterministic test transform, so validation accuracy is not
    measured on randomly augmented images.

    Args:
        batch_size: batch size used by all three loaders.
        train_fraction: share of the training split used for training; the
            remainder becomes the validation subset.
        root: directory where CIFAR-10 is downloaded / cached.
        num_workers: worker processes per DataLoader.
        seed: seed for the train/validation permutation, making the split
            reproducible across runs.

    Returns:
        ``(trainloader, valloader, testloader)``. Only the train loader shuffles.
    """
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # Two views of the same training files that differ only in their transform:
    # random_split on a single dataset would force the augmenting transform onto
    # the validation images as well.
    train_view = torchvision.datasets.CIFAR10(root=root, train=True,
                                              download=True, transform=transform_train)
    eval_view = torchvision.datasets.CIFAR10(root=root, train=True,
                                             download=True, transform=transform_test)

    num_samples = len(train_view)
    train_size = int(train_fraction * num_samples)
    # One shared permutation keeps the two subsets disjoint.
    generator = torch.Generator().manual_seed(seed)
    indices = torch.randperm(num_samples, generator=generator).tolist()
    trainset = torch.utils.data.Subset(train_view, indices[:train_size])
    valset = torch.utils.data.Subset(eval_view, indices[train_size:])

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=num_workers)
    valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
                                            shuffle=False, num_workers=num_workers)

    testset = torchvision.datasets.CIFAR10(root=root, train=False,
                                           download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=num_workers)

    return trainloader, valloader, testloader

# Function to train the model
def train(model, trainloader, criterion, optimizer, scaler, epoch, device=None, log_interval=100):
    """Run one training epoch with mixed-precision autocast and gradient scaling.

    Args:
        model: network to optimise; it is switched to training mode.
        trainloader: iterable of ``(inputs, labels)`` batches.
        criterion: loss, called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scaler: GradScaler that scales the loss for the backward pass and
            performs the optimizer step.
        epoch: zero-based epoch index; only used in progress messages.
        device: where batches are moved. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters).
        log_interval: positive number of batches between progress prints of
            the running mean loss.

    Returns:
        Mean of the per-batch losses over the whole epoch (0.0 if the loader
        yielded no batches).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    device = torch.device(device)
    # CUDA autocast only affects CUDA ops; enabling it elsewhere just warns.
    use_amp = device.type == 'cuda'

    model.train()
    running_loss = 0.0
    epoch_loss = 0.0
    num_batches = 0
    for batch_idx, (inputs, labels) in enumerate(trainloader):
        optimizer.zero_grad()
        inputs, labels = inputs.to(device), labels.to(device)
        with autocast(enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        num_batches += 1
        if batch_idx % log_interval == log_interval - 1:
            print(f'Epoch {epoch + 1}, Batch {batch_idx + 1}, Loss: {running_loss / log_interval:.3f}')
            running_loss = 0.0

    # Trailing batches after the last progress print still count toward the epoch mean.
    return epoch_loss / num_batches if num_batches else 0.0

# Function to validate the model
def validate(model, valloader, criterion, device=None):
    """Evaluate top-1 accuracy (in percent) and mean loss of ``model`` on ``valloader``.

    Prints accuracy and mean loss, and returns the accuracy.

    Args:
        model: network to evaluate; it is switched to eval mode.
        valloader: iterable of ``(inputs, labels)`` batches.
        criterion: loss, called as ``criterion(outputs, labels)``.
        device: where batches are moved. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters).

    Returns:
        Accuracy as a percentage in ``[0, 100]``.

    Raises:
        ValueError: if ``valloader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    with torch.no_grad():
        for inputs, labels in valloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            batch_size = labels.size(0)
            # Weight by batch size so a short final batch does not skew the mean
            # (assumes criterion averages over the batch, e.g. reduction='mean').
            loss_sum += criterion(outputs, labels).item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError('validate() received a valloader with no samples')
    accuracy = 100 * correct / total
    print(f'Validation Accuracy: {accuracy:.2f}%, Loss: {loss_sum / total:.4f}')
    return accuracy

# Function to prune the model
def prune_model(model, amount=0.2, permanent=False):
    """Globally L1-prune the ``weight`` of every Conv2d and Linear layer, in place.

    All candidate weights are ranked together by absolute value
    (``prune.global_unstructured`` with ``L1Unstructured``). The smallest
    ``amount`` of them is then masked to zero.

    Args:
        model: module to prune; it is modified in place and also returned.
        amount: fraction (float in [0, 1]) or absolute count (int) of weights
            to prune. On a model that is already pruned, PyTorch applies it to
            the still-unpruned weights, so repeated calls compound.
        permanent: if True, fold the masks into the weights with
            ``prune.remove``. The state_dict then holds plain ``weight`` tensors
            rather than ``weight_orig`` and ``weight_mask``.

    Returns:
        The same ``model`` object.
    """
    parameters_to_prune = [
        (module, 'weight')
        for module in model.modules()
        if isinstance(module, (nn.Conv2d, nn.Linear))
    ]
    # global_unstructured cannot handle an empty parameter list.
    if not parameters_to_prune:
        return model

    prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=amount)
    if permanent:
        for module, param_name in parameters_to_prune:
            prune.remove(module, param_name)
    return model

# Function to evaluate p50 and p90 performance
def evaluate_p_performance(model, testloader, criterion, device=None):
    """Print and return the 50th and 90th percentile of per-batch loss on ``testloader``.

    Each sample in the distribution is ``criterion(outputs, labels).item()`` for
    one batch. With a mean-reducing criterion this is that batch's mean loss,
    not a per-sample loss. A smaller final batch counts the same as a full one.
    This measures loss, not inference latency.

    Args:
        model: network to evaluate; it is switched to eval mode.
        testloader: iterable of ``(inputs, labels)`` batches.
        criterion: loss, called as ``criterion(outputs, labels)``.
        device: where batches are moved. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters).

    Returns:
        ``(p50, p90)`` as NumPy floats.

    Raises:
        ValueError: if ``testloader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    batch_losses = []
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            batch_losses.append(criterion(model(inputs), labels).item())

    if not batch_losses:
        raise ValueError('evaluate_p_performance() received a testloader with no batches')
    p50, p90 = np.percentile(batch_losses, [50, 90])
    print(f'p50 Performance: {p50:.4f}, p90 Performance: {p90:.4f}')
    return p50, p90

# Main function to setup and execute the training and validation
def main():
    """Train ResNetCIFAR10 on CIFAR-10, then evaluate and prune copies of the best model.

    Side effects:
        - Sets the module-level ``device`` global, for helpers in this file that
          read it.
        - Writes ``model.pth`` (best validation checkpoint).
        - Writes ``pruned_model_50.pth`` and ``pruned_model_70.pth``.
        - Writes ``pruned_model.pth`` (pruned at ``prune_model``'s default ratio).

    Each pruned model is pruned from its own copy of the best trained weights.
    Its masks are folded into plain ``weight`` tensors before saving.
    Any exception is reported and then re-raised.
    """
    import copy  # local import keeps this block self-contained

    global device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    try:
        model = ResNetCIFAR10().to(device)
        trainloader, valloader, testloader = get_data_loaders(batch_size=128)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        # NOTE(review): step_size equals num_epochs, so the decay point is only
        # reached after the final epoch — the LR never drops during this run.
        scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
        scaler = GradScaler()

        num_epochs = 10
        best_accuracy = 0.0
        best_state = None

        for epoch in range(num_epochs):
            train(model, trainloader, criterion, optimizer, scaler, epoch)
            accuracy = validate(model, valloader, criterion)
            scheduler.step()

            if accuracy > best_accuracy:
                best_accuracy = accuracy
                # Keep an in-memory snapshot so the best weights can be restored below.
                best_state = copy.deepcopy(model.state_dict())
                torch.save(best_state, 'model.pth')
                print(f"Model saved with accuracy: {accuracy:.2f}%")

            if accuracy >= 80:
                print(f"Reached 80% accuracy at epoch {epoch + 1}")
                break

        # Continue from the best checkpoint instead of the last epoch, and never
        # overwrite model.pth with a possibly worse final state.
        if best_state is not None:
            model.load_state_dict(best_state)
        else:
            torch.save(model.state_dict(), 'model.pth')
        print(f"Best trained model (validation accuracy {best_accuracy:.2f}%) saved as model.pth")

        # The baseline must be measured before any pruning touches the weights.
        print("Evaluating p50 and p90 performance before pruning")
        evaluate_p_performance(model, testloader, criterion)

        # Prune independent copies: pruning the same model repeatedly would
        # compound the sparsity and also corrupt the unpruned baseline.
        prune_ratios = [0.5, 0.7]  # Experimented with different pruning ratios
        for ratio in prune_ratios:
            percent = round(ratio * 100)
            pruned_model = prune_model(copy.deepcopy(model), amount=ratio)
            accuracy = validate(pruned_model, valloader, criterion)
            _make_pruning_permanent(pruned_model)
            torch.save(pruned_model.state_dict(), f'pruned_model_{percent}.pth')
            print(f"Pruned model saved with {percent}% pruning ratio. Validation Accuracy: {accuracy:.2f}%")

        # Default-ratio pruned model, again from an untouched copy.
        pruned_model = prune_model(copy.deepcopy(model))
        _make_pruning_permanent(pruned_model)
        torch.save(pruned_model.state_dict(), 'pruned_model.pth')
        print("Pruned model saved as pruned_model.pth")

    except Exception as e:
        print(f"An error occurred: {e}")
        # Re-raise so a failed run is not mistaken for a successful one.
        raise


def _make_pruning_permanent(model):
    """Fold pruning masks into plain ``weight`` tensors so the state_dict loads into an unpruned model."""
    for module in model.modules():
        # torch.nn.utils.prune leaves a ``weight_mask`` buffer (plus ``weight_orig``) on pruned modules.
        if hasattr(module, 'weight_mask'):
            prune.remove(module, 'weight')
    return model


if __name__ == '__main__':
    main()



Using device: cuda:0


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 181MB/s]


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 30397191.54it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


  return F.conv2d(input, weight, bias, self.stride,
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Epoch 1, Batch 100, Loss: 1.365
Epoch 1, Batch 200, Loss: 0.989
Epoch 1, Batch 300, Loss: 0.886
Validation Accuracy: 70.71%
Model saved with accuracy: 70.71%
Epoch 2, Batch 100, Loss: 0.821
Epoch 2, Batch 200, Loss: 0.754
Epoch 2, Batch 300, Loss: 0.740
Validation Accuracy: 75.67%
Model saved with accuracy: 75.67%
Epoch 3, Batch 100, Loss: 0.692
Epoch 3, Batch 200, Loss: 0.674
Epoch 3, Batch 300, Loss: 0.669
Validation Accuracy: 76.44%
Model saved with accuracy: 76.44%
Epoch 4, Batch 100, Loss: 0.607
Epoch 4, Batch 200, Loss: 0.640
Epoch 4, Batch 300, Loss: 0.612
Validation Accuracy: 77.39%
Model saved with accuracy: 77.39%
Epoch 5, Batch 100, Loss: 0.589
Epoch 5, Batch 200, Loss: 0.568
Epoch 5, Batch 300, Loss: 0.570
Validation Accuracy: 79.35%
Model saved with accuracy: 79.35%
Epoch 6, Batch 100, Loss: 0.540
Epoch 6, Batch 200, Loss: 0.549
Epoch 6, Batch 300, Loss: 0.532
Validation Accuracy: 79.57%
Model saved with accuracy: 79.57%
Epoch 7, Batch 100, Loss: 0.502
Epoch 7, Batch 200, 