# 1. Import Libraries

In [None]:
import copy
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from sklearn.metrics import classification_report, confusion_matrix
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


# 2. Load Data

In [None]:
# Hyperparameters
# Defined before the transforms because Resize below reads IMAGE_SIZE.
MODEL_NAME = 'swinv2_tiny_window8_256'
IMAGE_SIZE = 256
NUM_CLASSES = 4
BATCH_SIZE = 32
LEARNING_RATE = 5e-4
WEIGHT_DECAY = 0.01

# Per-channel normalisation statistics shared by every split
# (these are the values commonly used with ImageNet-pretrained backbones).
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training transform: resize plus random flips, rotation and colour jitter
train_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Validation / test transform: deterministic, no augmentation
val_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Dataset locations (one sub-folder per class, as ImageFolder expects)
train_dir = 'output_dataset/train'
val_dir = 'output_dataset/val'
test_dir = 'output_dataset/test'

# Only the training split is augmented; val and test share the deterministic transform
train_dataset = ImageFolder(root=train_dir, transform=train_transform)
val_dataset = ImageFolder(root=val_dir, transform=val_transform)
test_dataset = ImageFolder(root=test_dir, transform=val_transform)

# Create data loaders
# `batch_size` is kept as an alias of BATCH_SIZE for any cell that still reads it.
batch_size = BATCH_SIZE
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# Dataset sizes (used to average the per-epoch loss and accuracy)
dataset_sizes = {
    'train': len(train_dataset),
    'val': len(val_dataset)
}

# Data loaders dictionary, keyed by training phase
dataloaders = {
    'train': train_loader,
    'val': val_loader
}


# 3. Define the Model

In [None]:
# Prefer the GPU when one is visible; otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained SwinV2-Tiny model named by MODEL_NAME, with a
# classification head sized for NUM_CLASSES outputs. Using the shared
# constants keeps this cell in step with the hyperparameter block.
model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=NUM_CLASSES)

# Move model to device
model = model.to(device)

# 4. Fine-tune the Model

In [None]:
# Settings for the baseline fine-tuning run
LABEL_SMOOTHING = 0.1
WEIGHT_DECAY = 0.01
NUM_EPOCHS = 25

# Label-smoothed cross-entropy as the training objective
criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING)

# NOTE(review): the learning rate is hard-coded to 1e-4 here although the
# hyperparameter block defines LEARNING_RATE = 5e-4 — confirm which is intended.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=1e-4,
    weight_decay=WEIGHT_DECAY,
)

# Cosine learning-rate decay spanning the whole run
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

# Gradient scaler for mixed-precision training; train_model reads this
# module-level name when it scales the loss and steps the optimizer.
scaler = GradScaler()

# Training function
def train_model(model, criterion, optimizer, scheduler, num_epochs=NUM_EPOCHS):
    """Fine-tune ``model`` with mixed precision and keep the best-validation weights.

    Every epoch runs a 'train' pass followed by a 'val' pass. Loss and accuracy
    are summed per sample and divided by ``dataset_sizes[phase]``. Whenever the
    validation accuracy improves, a deep copy of the weights is kept; those
    weights are loaded back into ``model`` before returning.

    Besides its arguments, the function reads these notebook-level names at
    call time: ``dataloaders``, ``dataset_sizes``, ``device`` and ``scaler``.

    Args:
        model: network to train; updated in place.
        criterion: loss callable taking ``(outputs, labels)``.
        optimizer: optimizer stepped once per training batch via ``scaler``.
        scheduler: learning-rate scheduler stepped once per epoch, after the
            training pass.
        num_epochs: number of train+val epochs to run.

    Returns:
        ``(model, history)`` where ``history`` maps 'train_loss', 'train_acc',
        'val_loss' and 'val_acc' to lists of per-epoch Python floats
        (accuracies are fractions in [0, 1]).
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    # Per-epoch metrics, one list entry per completed phase
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch + 1}/{num_epochs}')

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            # train(True) enables training mode, train(False) is eval mode
            model.train(is_train)

            running_loss = 0.0
            running_corrects = 0  # plain int so the accuracy stays a Python float

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Gradients are only produced (and so only need clearing) when training
                if is_train:
                    optimizer.zero_grad()

                # Only the forward pass and loss run under autocast
                with torch.set_grad_enabled(is_train), autocast():
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                # Backward and the optimizer step happen outside the autocast
                # region, as the PyTorch AMP examples recommend.
                if is_train:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                _, preds = torch.max(outputs, 1)
                # Weight the batch-mean loss by batch size so the epoch value
                # is a per-sample average even when the last batch is smaller.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            if is_train:
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc * 100:.4f}')

            # Store in history
            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc)

            # Snapshot the weights whenever validation accuracy improves
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print(f'\nTraining complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc * 100:.4f}')

    # Return the model with its best-validation weights restored
    model.load_state_dict(best_model_wts)
    return model, history

# Run the baseline fine-tuning; `model` comes back with its best-validation
# weights loaded and `history` holds the per-epoch loss/accuracy lists.
model, history = train_model(
    model,
    criterion,
    optimizer,
    scheduler,
    num_epochs=NUM_EPOCHS,
)



Epoch 1/1
train Loss: 1.3544 Acc: 42.8571
val Loss: 1.2412 Acc: 50.0000

Training complete in 2m 1s
Best val Acc: 50.0000


# 5. Save the Model

In [8]:
# Write the fine-tuned weights (state_dict only, not the whole module) to disk
torch.save(obj=model.state_dict(), f='swin_t_baseline.pth')

# 6. Prune the Model

### 6.1 Load pretrained baseline model

In [None]:
import torch
import torch.nn as nn
import timm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the architecture the checkpoint was saved from (MODEL_NAME /
# NUM_CLASSES from the hyperparameter block). A different architecture would
# make load_state_dict fail on mismatched keys and shapes.
# pretrained=False: the pretrained weights are not needed, ours are loaded next.
model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=NUM_CLASSES)

# Load your own saved trained weights.
# map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
model.load_state_dict(torch.load('swin_t_baseline.pth', map_location=device))

# Move model to device
model = model.to(device)


### 6.2 Pruning configuration with 30% of weights pruned

In [12]:
import torch.nn.utils.prune as prune

def prune_and_remove(model, amount=0.3):
    """
    Apply L1 unstructured pruning to the weights of every Conv2d and Linear layer.

    For each matching layer, `amount` of its `weight` entries (0.3 by default)
    is pruned with `prune.l1_unstructured`, and `prune.remove` is then called
    so the layer is left with a plain `weight` parameter instead of the
    pruning reparametrization. The model is modified in place and returned.
    """
    prunable_types = (nn.Conv2d, nn.Linear)
    targets = [layer for layer in model.modules() if isinstance(layer, prunable_types)]

    for layer in targets:
        # Mask the selected entries of `weight` ...
        prune.l1_unstructured(layer, name='weight', amount=amount)
        # ... then bake the mask into the parameter and drop the pruning hook
        prune.remove(layer, 'weight')

    return model

# prune_and_remove returns the model it was given, so `pruned_model` and
# `model` refer to the same (now pruned) object.
pruned_model = prune_and_remove(model, 0.3)



### 6.3 Train pruned model and save it

In [None]:
# Define loss function and optimizer for pruned model
LABEL_SMOOTHING = 0.1
WEIGHT_DECAY = 0.01
NUM_EPOCHS = 10  # or adjust as needed
criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING)
optimizer = torch.optim.AdamW(pruned_model.parameters(), lr=1e-4, weight_decay=WEIGHT_DECAY)

# Learning rate scheduler for pruned model
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

# Fresh AMP scaler for this run; train_model reads the module-level `scaler`
scaler = GradScaler()

# NOTE(review): prune_and_remove called prune.remove(), so no pruning mask is
# active during this fine-tuning and the optimizer updates every weight,
# including the ones zeroed by pruning. The saved model is presumably no longer
# 30% sparse — verify, and keep the masks during training if sparsity matters.

# Training function remains the same. Pass NUM_EPOCHS rather than a literal so
# the epoch count always matches the scheduler's T_max above.
pruned_model, pruned_history = train_model(pruned_model, criterion, optimizer, scheduler, num_epochs=NUM_EPOCHS)

# Save pruned model
torch.save(pruned_model.state_dict(), 'swin_t_pruned.pth')



Epoch 1/1
train Loss: 1.2754 Acc: 50.2381
val Loss: 1.2380 Acc: 50.0000

Training complete in 2m 2s
Best val Acc: 50.0000


# 7. Evaluate the Model

In [15]:
# Define evaluation function
def evaluate_model(model, dataloader):
    """Return the top-1 accuracy of ``model`` over ``dataloader`` as a percentage.

    The model is switched to eval mode (and left in it) and run without
    gradient tracking. Batches are moved to the device the model's parameters
    live on, so the function does not depend on a notebook-level ``device``.

    Args:
        model: module mapping a batch of inputs to per-class scores.
        dataloader: iterable yielding ``(inputs, labels)`` tensor pairs.

    Returns:
        Accuracy in percent (0-100) as a float; 0.0 when ``dataloader`` yields
        no samples, instead of dividing by zero.
    """
    model.eval()

    # Follow the model's own device; a parameter-less model falls back to CPU
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(target_device), labels.to(target_device)
            outputs = model(inputs)
            # Index of the highest score per sample is the predicted class
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        return 0.0
    accuracy = 100 * correct / total
    return accuracy

# Evaluate baseline model
# map_location keeps the load working when the checkpoint was written on a
# GPU machine and this cell runs on a CPU-only one.
model.load_state_dict(torch.load('swin_t_baseline.pth', map_location=device))
baseline_accuracy = evaluate_model(model, test_loader)
print(f'Baseline Model Accuracy: {baseline_accuracy}%')

# Evaluate pruned model (same architecture, so the same `model` object is reused)
model.load_state_dict(torch.load('swin_t_pruned.pth', map_location=device))
pruned_accuracy = evaluate_model(model, test_loader)
print(f'Pruned Model Accuracy: {pruned_accuracy}%')

Baseline Model Accuracy: 49.45054945054945%
Pruned Model Accuracy: 49.45054945054945%
