<a href="https://colab.research.google.com/github/MammadovN/Machine_Learning/blob/main/projects/03_deep_learning/transfer_learning/transfer_learning_classification/transfer_learning_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

In [2]:
def get_data_loaders(batch_size=32):
    """
    Uses CIFAR-10 as the dataset, applying necessary transforms.
    Returns dataloaders dict, dataset sizes, and class names.
    """
    # Data augmentation and normalization for training
    # Just normalization for validation/test
    train_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(224, padding=4),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])
    test_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])

    train_ds = datasets.CIFAR10(root='./data', train=True,
                                download=True, transform=train_transform)
    val_ds   = datasets.CIFAR10(root='./data', train=False,
                                download=True, transform=test_transform)
    # For simplicity, use val_ds as test set as well
    dataloaders = {
        'train': DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2),
        'val':   DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=2),
        'test':  DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=2)
    }
    sizes = {'train': len(train_ds), 'val': len(val_ds), 'test': len(val_ds)}
    class_names = train_ds.classes
    return dataloaders, sizes, class_names

In [3]:
def build_model(num_classes, feature_extract=True):
    """
    Loads pretrained ResNet50, freezes conv layers if feature_extract,
    and replaces the final FC layer for num_classes.
    """
    model = models.resnet50(pretrained=True)
    if feature_extract:
        for param in model.parameters():
            param.requires_grad = False
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)
    return model

In [4]:
def train_model(model, dataloaders, sizes, criterion, optimizer, device, num_epochs=10):
    """
    Trains the model, returns best model (by val accuracy).
    """
    best_weights = model.state_dict()
    best_acc = 0.0
    model.to(device)

    for epoch in range(1, num_epochs+1):
        print(f"Epoch {epoch}/{num_epochs}")
        for phase in ['train', 'val']:
            model.train() if phase=='train' else model.eval()
            running_loss, running_corrects = 0.0, 0

            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase=='train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)
                    if phase=='train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds==labels.data)

            epoch_loss = running_loss / sizes[phase]
            epoch_acc = running_corrects.double() / sizes[phase]
            print(f" {phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

            if phase=='val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_weights = model.state_dict()

    model.load_state_dict(best_weights)
    return model

In [5]:
def evaluate_model(model, dataloader, device):
    """
    Evaluates on test set and returns accuracy.
    """
    model.eval()
    corrects, total = 0, 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            corrects += torch.sum(preds==labels.data).item()
            total += labels.size(0)
    return corrects / total

In [6]:
def main():
    # Hyperparameters
    epochs = 10
    batch_size = 32
    lr = 1e-3
    feature_extract = True
    output_path = 'best_resnet50.pth'

    # Prepare data
    dataloaders, sizes, class_names = get_data_loaders(batch_size=batch_size)

    # Build model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = build_model(len(class_names), feature_extract=feature_extract)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)

    # Train & evaluate
    best_model = train_model(model, dataloaders, sizes, criterion, optimizer, device, num_epochs=epochs)
    test_acc = evaluate_model(best_model, dataloaders['test'], device)
    print(f"Test Accuracy: {test_acc:.4f}")

    # Save
    torch.save({'model_state_dict': best_model.state_dict(),
                'class_names': class_names}, output_path)


In [7]:
if __name__ == '__main__':
    main()

100%|██████████| 170M/170M [00:13<00:00, 12.4MB/s]
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 200MB/s]


Epoch 1/10
 train Loss: 0.7419 Acc: 0.7545
 val Loss: 0.6188 Acc: 0.7914
Epoch 2/10
 train Loss: 0.6098 Acc: 0.7928
 val Loss: 0.5712 Acc: 0.8035
Epoch 3/10
 train Loss: 0.5901 Acc: 0.7989
 val Loss: 0.5523 Acc: 0.8124
Epoch 4/10
 train Loss: 0.5704 Acc: 0.8061
 val Loss: 0.5273 Acc: 0.8214
Epoch 5/10
 train Loss: 0.5532 Acc: 0.8118
 val Loss: 0.5663 Acc: 0.8145
Epoch 6/10
 train Loss: 0.5423 Acc: 0.8151
 val Loss: 0.5450 Acc: 0.8185
Epoch 7/10
 train Loss: 0.5324 Acc: 0.8173
 val Loss: 0.4996 Acc: 0.8305
Epoch 8/10
 train Loss: 0.5262 Acc: 0.8213
 val Loss: 0.5201 Acc: 0.8216
Epoch 9/10
 train Loss: 0.5309 Acc: 0.8220
 val Loss: 0.5860 Acc: 0.8014
Epoch 10/10
 train Loss: 0.5203 Acc: 0.8223
 val Loss: 0.5394 Acc: 0.8194
Test Accuracy: 0.8194
