In [None]:
import copy
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, random_split

In [None]:
# Seed every random number generator used in this notebook with one value.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Dataset selector consumed by the data-loading cell ('fashMNIST' or 'MNIST').
data = 'fashMNIST'

In [None]:
# Architectures (torchvision.models constructor names) trained by this notebook.
model_names = ['resnet101', 'resnet152']

# Families whose classification head is a single linear layer stored in `model.fc`.
_FC_HEAD_FAMILIES = ('resnet', 'resnext', 'shufflenet', 'googlenet', 'inception')
# Families whose head is the last linear layer of a `model.classifier` Sequential.
_SEQUENTIAL_HEAD_FAMILIES = ('vgg', 'alexnet', 'mnasnet', 'mobilenet', 'efficientnet', 'convnext')


def load_model(model_name, num_classes=10):
    """Build a pretrained torchvision model with a fresh `num_classes`-way head.

    Args:
        model_name: Name of a constructor in `torchvision.models`
            (e.g. 'resnet101'). The family is detected by substring match.
        num_classes: Number of output classes for the replaced head.
            Defaults to 10, the value that was previously hard-coded.

    Returns:
        The model, moved to the notebook-level `device`.

    Raises:
        AttributeError: If `torchvision.models` has no attribute `model_name`.
        ValueError: If the architecture family is not handled here. Previously
            this case only printed a message and returned a model that still
            had its original head.
    """
    model = getattr(models, model_name)(pretrained=True)

    if any(family in model_name for family in _FC_HEAD_FAMILIES):
        # Read the width from the existing layer instead of hard-coding it, so
        # variants with a wider head (e.g. shufflenet_v2_x2_0) work too.
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        # With auxiliary classifiers enabled, Inception/GoogLeNet return a
        # namedtuple in train mode, and the auxiliary heads still emit the
        # original class count. Disable them so the forward pass yields a
        # plain logits tensor in both train and eval mode.
        if getattr(model, 'aux_logits', False):
            model.aux_logits = False
            for aux_name in ('AuxLogits', 'aux1', 'aux2'):
                if hasattr(model, aux_name):
                    setattr(model, aux_name, None)
    elif 'densenet' in model_name:
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    elif 'squeezenet' in model_name:
        # SqueezeNet classifies with a 1x1 convolution rather than a linear layer.
        model.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model.num_classes = num_classes
    elif any(family in model_name for family in _SEQUENTIAL_HEAD_FAMILIES):
        # The final classifier entry is the output layer for all of these
        # families. Indexing with -1 (rather than a fixed position) also covers
        # MobileNetV3, where classifier[1] is an activation, and torchvision's
        # EfficientNet, which has no `_fc` attribute.
        old_head = model.classifier[-1]
        model.classifier[-1] = nn.Linear(old_head.in_features, num_classes)
    else:
        raise ValueError(f'Unsupported model: {model_name!r}')

    return model.to(device)

In [None]:
# Preprocessing applied to every image: resize to 299x299, expand the
# single-channel image to three channels, then convert to a tensor.
# NOTE(review): no mean/std normalization is applied before feeding the
# pretrained backbones - confirm this is intentional.
transform = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
])

# Map the `data` selector from the config cell to its torchvision dataset class.
dataset_classes = {
    'fashMNIST': datasets.FashionMNIST,
    'MNIST': datasets.MNIST,
}
if data not in dataset_classes:
    # Fail here with a clear message instead of hitting a NameError on
    # `training_dataset` a few lines further down.
    raise ValueError(
        f"data not selected: {data!r} (expected one of {sorted(dataset_classes)})"
    )
dataset_class = dataset_classes[data]
training_dataset = dataset_class(root='./data', train=True, download=True, transform=transform)
test_dataset = dataset_class(root='./data', train=False, download=True, transform=transform)

# Hold out 20% of the official training split for validation. The explicit
# generator makes the split depend only on `seed`, so re-running this cell
# (or running it after other code has consumed the global RNG) reproduces
# the same train/validation partition.
train_size = int(0.8 * len(training_dataset))
val_size = len(training_dataset) - train_size
train_dataset, val_dataset = random_split(
    training_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(seed),
)

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
num_epochs = 50


def _evaluate(model, loader, criterion=None, collect_outputs=False):
    """Run `model` over `loader` in eval mode with gradients disabled.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        loader: Iterable yielding `(inputs, labels)` batches.
        criterion: Optional loss function. When given, its per-batch value is
            summed over all batches.
        collect_outputs: When True, the raw model outputs and the labels of
            every batch are moved to the CPU and returned.

    Returns:
        Tuple `(loss_sum, accuracy, outputs, labels)`: the summed batch loss
        (0.0 without a criterion), the accuracy in percent (0.0 for an empty
        loader), and two lists of per-batch CPU tensors (empty unless
        `collect_outputs` is True).
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    collected_outputs = []
    collected_labels = []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            if criterion is not None:
                loss_sum += criterion(outputs, labels).item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            if collect_outputs:
                collected_outputs.append(outputs.cpu())
                collected_labels.append(labels.cpu())
    accuracy = 100 * correct / total if total else 0.0
    return loss_sum, accuracy, collected_outputs, collected_labels


for model_name in model_names:
    print(f"Training {model_name}...")
    model = load_model(model_name)

    # Early-stopping state (tracked on validation loss).
    best_val_loss = float('inf')
    patience = 5
    counter = 0

    # Checkpoint state (tracked on validation accuracy).
    best_val_accuracy = 0.0
    best_model = None
    best_predictions = None
    best_labels = None

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # `verbose=True` is deprecated on LR schedulers in recent PyTorch; learning
    # rate changes are reported explicitly after `scheduler.step` instead.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        train_accuracy = 100 * correct / total

        val_loss, val_accuracy, prediction_tensors, label_tensors = _evaluate(
            model, val_loader, criterion, collect_outputs=True
        )

        previous_lr = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        if current_lr != previous_lr:
            print(f'Reducing learning rate from {previous_lr:.4e} to {current_lr:.4e}.')

        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss/len(train_loader):.4f}, Train Accuracy: {train_accuracy:.2f}%, Val Loss: {val_loss/len(val_loader):.4f}, Val Accuracy: {val_accuracy:.2f}%')

        # Checkpoint BEFORE the early-stopping check so an accuracy improvement
        # on the final epoch is not lost.
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            # state_dict() returns references to the live parameters; without a
            # deep copy the "best" snapshot would silently follow later updates.
            best_model = copy.deepcopy(model.state_dict())
            # Keep the validation outputs that belong to this checkpoint.
            best_predictions = torch.cat(prediction_tensors)
            best_labels = torch.cat(label_tensors)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print(f'Validation loss did not improve for {patience} epochs. Early stopping...')
                break

    print(f"Training {model_name} completed.")

    if best_model is not None:
        torch.save(best_model, f'{model_name}_best_model.pth')

        # Restore the best checkpoint so the test evaluation and the saved
        # validation predictions refer to the same weights.
        model.load_state_dict(best_model)
        prediction_tensors = best_predictions
        label_tensors = best_labels
    else:
        # No epoch produced a checkpoint; fall back to the last epoch's outputs.
        prediction_tensors = torch.cat(prediction_tensors)
        label_tensors = torch.cat(label_tensors)

    result_dict = {'predictionVectors': prediction_tensors, 'labelVectors': label_tensors}
    result_file = f'{model_name}.pt'
    torch.save(result_dict, result_file)

    _, test_accuracy, _, _ = _evaluate(model, test_loader)
    print(f'Test Accuracy for {model_name}: {test_accuracy:.2f}%')

    # The scheduler references the optimizer and the snapshot holds a full copy
    # of the weights, so both must be dropped for empty_cache() to release the
    # memory before the next model is loaded. The snapshot is already on disk.
    del model, criterion, optimizer, scheduler, best_model, prediction_tensors, label_tensors
    torch.cuda.empty_cache()