To be up-to-date on the most current version of this code. Check out my GitHub repository: https://github.com/Neatherblok/SnowDetection

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader
from tqdm import tqdm

## Data Preparation

In [2]:
from Data_Preparation.Preparation import CustomDataLoader

In [3]:
# training data properties
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
BATCH_SIZE = 4

In [4]:
# Instantiate the CustomDataLoader class for training
train_data_loader = CustomDataLoader(data_path="./data", batch_size=BATCH_SIZE, dataset_type="train", mean=MEAN, std=STD).data_loader
val_data_loader = CustomDataLoader(data_path="./data", batch_size=BATCH_SIZE, dataset_type="val", mean=MEAN, std=STD).data_loader
test_data_loader = CustomDataLoader(data_path="./data", batch_size=BATCH_SIZE, dataset_type="test", mean=MEAN, std=STD).data_loader

image_datasets = {'train':train_data_loader.dataset, 'val':val_data_loader.dataset, 'test':test_data_loader.dataset}
dataloaders = {'train':train_data_loader, 'val':val_data_loader, 'test':test_data_loader}

## Initializing VGG19 and ResNet50 Finetuning

In [28]:
# Load pre-trained models
vgg19 = models.vgg19(pretrained=True)
resnet50 = models.resnet50(pretrained=True)

In [29]:
# Freeze parameters so we don't backprop through them
for param in vgg19.parameters():
    param.requires_grad = False
for param in resnet50.parameters():
    param.requires_grad = False

In [30]:
# Replace the classifier with a new one
num_classes = len(train_data_loader.dataset.classes)
vgg19.classifier[6] = nn.Linear(4096, num_classes)
resnet50.fc = nn.Linear(2048, num_classes)

In [31]:
LR = 0.0001
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer_vgg19 = optim.Adam(vgg19.parameters(), lr=LR)
optimizer_resnet50 = optim.Adam(resnet50.parameters(), lr=LR)

## Training VGG19 and ResNet50

In [32]:
from sklearn.metrics import f1_score

# Train the models
num_epochs = 25
for epoch in tqdm(range(num_epochs)):
    for phase in ['train', 'val']:
        if phase == 'train':
            vgg19.train()
            resnet50.train()
        else:
            vgg19.eval()
            resnet50.eval()

        running_loss = 0.0
        corrects = 0
        all_preds_vgg19 = []
        all_preds_resnet50 = []
        all_labels = []

        torch.manual_seed(2809)
        for inputs, labels in dataloaders[phase]:
            optimizer_vgg19.zero_grad()
            optimizer_resnet50.zero_grad()

            with torch.set_grad_enabled(phase == 'train'):
                outputs_vgg19 = vgg19(inputs)
                outputs_resnet50 = resnet50(inputs)
                _, preds_vgg19 = torch.max(outputs_vgg19, 1)
                _, preds_resnet50 = torch.max(outputs_resnet50, 1)

                loss_vgg19 = criterion(outputs_vgg19, labels)
                loss_resnet50 = criterion(outputs_resnet50, labels)

                if phase == 'train':
                    loss_vgg19.backward()
                    loss_resnet50.backward()
                    optimizer_vgg19.step()
                    optimizer_resnet50.step()

            running_loss += loss_vgg19.item() * inputs.size(0)
            running_loss += loss_resnet50.item() * inputs.size(0)
            corrects += torch.sum(preds_vgg19 == labels.data)
            corrects += torch.sum(preds_resnet50 == labels.data)
            all_preds_vgg19.extend(preds_vgg19.cpu().numpy())
            all_preds_resnet50.extend(preds_resnet50.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

        epoch_loss = running_loss / len(image_datasets[phase])
        epoch_acc = corrects.double() / len(image_datasets[phase])
        epoch_f1_vgg19 = f1_score(all_labels, all_preds_vgg19, average='macro')
        epoch_f1_resnet50 = f1_score(all_labels, all_preds_resnet50, average='macro')

        print('{} Loss: {:.4f} | Acc: {:.4f} | F1 VGG19: {:.4f} | F1 ResNet50: {:.4f}'.format(phase, epoch_loss, epoch_acc, epoch_f1_vgg19, epoch_f1_resnet50))

    # Save model every 5 epochs
    if (epoch + 1) % 5 == 0:
        torch.save(vgg19.state_dict(), f'models/lr_{LR}/vgg19_epoch_{epoch+1}.pt')
        torch.save(resnet50.state_dict(), f'models/lr_{LR}/resnet50_epoch_{epoch+1}.pt')


  0%|                                                                                           | 0/25 [00:00<?, ?it/s]

train Loss: 35.8963 | Acc: 1.3205 | F1 VGG19: 0.7820 | F1 ResNet50: 0.5357


  4%|███▎                                                                               | 1/25 [00:22<09:01, 22.56s/it]

val Loss: 2.4773 | Acc: 1.8889 | F1 VGG19: 1.0000 | F1 ResNet50: 0.8875
train Loss: 18.4936 | Acc: 1.6282 | F1 VGG19: 0.8717 | F1 ResNet50: 0.7564


  8%|██████▋                                                                            | 2/25 [00:44<08:36, 22.44s/it]

val Loss: 22.7621 | Acc: 1.6111 | F1 VGG19: 0.9443 | F1 ResNet50: 0.6250
train Loss: 11.2508 | Acc: 1.7949 | F1 VGG19: 0.9872 | F1 ResNet50: 0.8077


 12%|█████████▉                                                                         | 3/25 [01:08<08:28, 23.10s/it]

val Loss: 18.9377 | Acc: 1.7222 | F1 VGG19: 1.0000 | F1 ResNet50: 0.6990
train Loss: 26.1065 | Acc: 1.6923 | F1 VGG19: 1.0000 | F1 ResNet50: 0.6921


 16%|█████████████▎                                                                     | 4/25 [01:35<08:31, 24.35s/it]

val Loss: 7.5678 | Acc: 1.8333 | F1 VGG19: 1.0000 | F1 ResNet50: 0.8286
train Loss: 28.3414 | Acc: 1.6923 | F1 VGG19: 1.0000 | F1 ResNet50: 0.6915
val Loss: 8.2696 | Acc: 1.8333 | F1 VGG19: 1.0000 | F1 ResNet50: 0.8286


 20%|████████████████▌                                                                  | 5/25 [02:03<08:38, 25.91s/it]

train Loss: 20.0313 | Acc: 1.8077 | F1 VGG19: 1.0000 | F1 ResNet50: 0.8061


 24%|███████████████████▉                                                               | 6/25 [02:29<08:10, 25.84s/it]

val Loss: 25.5840 | Acc: 1.7778 | F1 VGG19: 1.0000 | F1 ResNet50: 0.7662
train Loss: 7.4544 | Acc: 1.8974 | F1 VGG19: 1.0000 | F1 ResNet50: 0.8974


 28%|███████████████████████▏                                                           | 7/25 [02:53<07:35, 25.30s/it]

val Loss: 3.5855 | Acc: 1.9444 | F1 VGG19: 1.0000 | F1 ResNet50: 0.9443
train Loss: 0.9168 | Acc: 1.9615 | F1 VGG19: 1.0000 | F1 ResNet50: 0.9615


 32%|██████████████████████████▌                                                        | 8/25 [03:23<07:36, 26.83s/it]

val Loss: 13.3565 | Acc: 1.8333 | F1 VGG19: 1.0000 | F1 ResNet50: 0.8286
train Loss: 0.2648 | Acc: 1.9744 | F1 VGG19: 1.0000 | F1 ResNet50: 0.9744


 36%|█████████████████████████████▉                                                     | 9/25 [03:47<06:56, 26.02s/it]

val Loss: 4.6250 | Acc: 1.8889 | F1 VGG19: 1.0000 | F1 ResNet50: 0.8889
train Loss: 3.8020 | Acc: 1.9231 | F1 VGG19: 1.0000 | F1 ResNet50: 0.9230
val Loss: 11.5685 | Acc: 1.8333 | F1 VGG19: 1.0000 | F1 ResNet50: 0.8286


 40%|████████████████████████████████▊                                                 | 10/25 [04:19<06:56, 27.79s/it]

train Loss: 0.3508 | Acc: 1.9872 | F1 VGG19: 1.0000 | F1 ResNet50: 0.9872


 44%|████████████████████████████████████                                              | 11/25 [04:43<06:11, 26.56s/it]

val Loss: 12.4608 | Acc: 1.8333 | F1 VGG19: 1.0000 | F1 ResNet50: 0.8286
train Loss: 1.6779 | Acc: 1.9615 | F1 VGG19: 1.0000 | F1 ResNet50: 0.9615


 48%|███████████████████████████████████████▎                                          | 12/25 [05:07<05:35, 25.79s/it]

val Loss: 4.4295 | Acc: 1.8889 | F1 VGG19: 1.0000 | F1 ResNet50: 0.8875
train Loss: 4.3142 | Acc: 1.9359 | F1 VGG19: 1.0000 | F1 ResNet50: 0.9358


 52%|██████████████████████████████████████████▋                                       | 13/25 [05:32<05:05, 25.47s/it]

val Loss: 6.5616 | Acc: 1.8889 | F1 VGG19: 1.0000 | F1 ResNet50: 0.8875
train Loss: 0.0000 | Acc: 2.0000 | F1 VGG19: 1.0000 | F1 ResNet50: 1.0000


 56%|█████████████████████████████████████████████▉                                    | 14/25 [05:56<04:34, 24.96s/it]

val Loss: 21.1792 | Acc: 1.7778 | F1 VGG19: 1.0000 | F1 ResNet50: 0.7662
train Loss: 0.1707 | Acc: 1.9872 | F1 VGG19: 1.0000 | F1 ResNet50: 0.9872
val Loss: 6.6491 | Acc: 1.8333 | F1 VGG19: 1.0000 | F1 ResNet50: 0.8328


 60%|█████████████████████████████████████████████████▏                                | 15/25 [06:25<04:23, 26.31s/it]

train Loss: 0.1652 | Acc: 1.9872 | F1 VGG19: 1.0000 | F1 ResNet50: 0.9872


 64%|████████████████████████████████████████████████████▍                             | 16/25 [06:50<03:53, 25.94s/it]

val Loss: 5.6308 | Acc: 1.8889 | F1 VGG19: 1.0000 | F1 ResNet50: 0.8889
train Loss: 3.4507 | Acc: 1.9231 | F1 VGG19: 1.0000 | F1 ResNet50: 0.9230


 68%|███████████████████████████████████████████████████████▊                          | 17/25 [07:14<03:22, 25.27s/it]

val Loss: 11.3800 | Acc: 1.8333 | F1 VGG19: 1.0000 | F1 ResNet50: 0.8286
train Loss: 1.9151 | Acc: 1.9359 | F1 VGG19: 1.0000 | F1 ResNet50: 0.9358


 72%|███████████████████████████████████████████████████████████                       | 18/25 [07:43<03:05, 26.47s/it]

val Loss: 8.5001 | Acc: 1.8889 | F1 VGG19: 1.0000 | F1 ResNet50: 0.8875
train Loss: 0.0000 | Acc: 2.0000 | F1 VGG19: 1.0000 | F1 ResNet50: 1.0000


 76%|██████████████████████████████████████████████████████████████▎                   | 19/25 [08:09<02:38, 26.42s/it]

val Loss: 6.0333 | Acc: 1.8333 | F1 VGG19: 1.0000 | F1 ResNet50: 0.8328
train Loss: 0.0000 | Acc: 2.0000 | F1 VGG19: 1.0000 | F1 ResNet50: 1.0000
val Loss: 7.9885 | Acc: 1.7778 | F1 VGG19: 1.0000 | F1 ResNet50: 0.7750


 80%|█████████████████████████████████████████████████████████████████▌                | 20/25 [08:35<02:11, 26.27s/it]

train Loss: 0.0000 | Acc: 2.0000 | F1 VGG19: 1.0000 | F1 ResNet50: 1.0000


 84%|████████████████████████████████████████████████████████████████████▉             | 21/25 [09:04<01:48, 27.15s/it]

val Loss: 8.2448 | Acc: 1.7778 | F1 VGG19: 1.0000 | F1 ResNet50: 0.7750
train Loss: 0.0000 | Acc: 2.0000 | F1 VGG19: 1.0000 | F1 ResNet50: 1.0000


 88%|████████████████████████████████████████████████████████████████████████▏         | 22/25 [09:31<01:21, 27.07s/it]

val Loss: 8.2769 | Acc: 1.7778 | F1 VGG19: 1.0000 | F1 ResNet50: 0.7750
train Loss: 0.0000 | Acc: 2.0000 | F1 VGG19: 1.0000 | F1 ResNet50: 1.0000


 92%|███████████████████████████████████████████████████████████████████████████▍      | 23/25 [09:56<00:52, 26.41s/it]

val Loss: 8.2810 | Acc: 1.7778 | F1 VGG19: 1.0000 | F1 ResNet50: 0.7750
train Loss: 0.0000 | Acc: 2.0000 | F1 VGG19: 1.0000 | F1 ResNet50: 1.0000


 96%|██████████████████████████████████████████████████████████████████████████████▋   | 24/25 [10:21<00:26, 26.06s/it]

val Loss: 8.2815 | Acc: 1.7778 | F1 VGG19: 1.0000 | F1 ResNet50: 0.7750
train Loss: 0.0000 | Acc: 2.0000 | F1 VGG19: 1.0000 | F1 ResNet50: 1.0000
val Loss: 8.2815 | Acc: 1.7778 | F1 VGG19: 1.0000 | F1 ResNet50: 0.7750


100%|██████████████████████████████████████████████████████████████████████████████████| 25/25 [10:46<00:00, 25.87s/it]


## Ensemble Models

In [33]:
import torch

def adjust_weights_based_on_f1(model1_f1, model2_f1):
    # Calculate inverses of F1 scores
    inverse_model1_f1 = 1 / model1_f1
    inverse_model2_f1 = 1 / model2_f1
    
    # Normalize inverses
    total_inverse = inverse_model1_f1 + inverse_model2_f1
    weight_model1 = inverse_model1_f1 / total_inverse
    weight_model2 = inverse_model2_f1 / total_inverse
    
    return weight_model1, weight_model2

def ensemble_predict(model1, model2, dataloader, model1_f1=None, model2_f1=None):
    ens_predictions = []
    true_label = []
    model1_predictions = []
    model2_predictions = []
    weights = None
    
    # Adjust weights based on F1 scores if provided
    if model1_f1 is not None and model2_f1 is not None:
        weights = adjust_weights_based_on_f1(model1_f1, model2_f1)
    elif weights is None:
        weights = [0.5, 0.5]  # Default weights if not provided
    
    for inputs, labels in dataloader:
        outputs1 = model1(inputs)
        outputs2 = model2(inputs)
        
        # Weighted averaging
        weighted_outputs = (outputs1 * weights[0]) + (outputs2 * weights[1])
        
        ens_predictions.extend(torch.max(weighted_outputs, 1)[1].tolist())
        true_label.extend(labels.tolist())
        model1_predictions.extend(torch.max(outputs1, 1)[1].tolist())
        model2_predictions.extend(torch.max(outputs2, 1)[1].tolist())
        
    return {
        'ensemble_pred': ens_predictions,
        'label': true_label,
        '1_pred': model1_predictions,
        '2_pred': model2_predictions
    }


In [34]:
from sklearn.metrics import f1_score, accuracy_score

# print(ensemble_predictions['label'])
# print(ensemble_predictions['ensemble_pred'])
# print(ensemble_predictions['1_pred'])
# print(ensemble_predictions['2_pred'])

model_predictions = ensemble_predict(vgg19, resnet50, dataloaders['test'])

vgg_f1 = f1_score(model_predictions['label'], model_predictions['1_pred'], average='macro')
resnet_f1 = f1_score(model_predictions['label'], model_predictions['2_pred'], average='macro')
vgg_acc = accuracy_score(model_predictions['label'], model_predictions['1_pred'])
resnet_acc = accuracy_score(model_predictions['label'], model_predictions['2_pred'])

ensemble_predictions = ensemble_predict(vgg19, resnet50, dataloaders['test'], vgg_f1, resnet_f1)
ensemble_f1 = f1_score(ensemble_predictions['label'], ensemble_predictions['ensemble_pred'], average='macro')
ensemble_acc = accuracy_score(ensemble_predictions['label'], ensemble_predictions['ensemble_pred'])

print(f"True values: {model_predictions['label']}")
print(f"Ensemble predictions: {ensemble_predictions['ensemble_pred']}")
print(f"VGG-19 predictions: {model_predictions['1_pred']}")
print(f"ResNet-50 predictions: {model_predictions['2_pred']}")

print("********************************")

print(f'Ensemble F1 score: {ensemble_f1}')
print(f'VGG-19 F1 score: {vgg_f1}')
print(f'ResNet-50 F1 score: {resnet_f1}')

print("********************************")

print(f'Ensemble accuracy: {ensemble_acc}')
print(f'VGG-19 accuracy: {vgg_acc}')
print(f'ResNet-50 accuracy: {resnet_acc}')


True values: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Ensemble predictions: [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1]
VGG-19 predictions: [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1]
ResNet-50 predictions: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1]
********************************
Ensemble F1 score: 0.7179487179487178
VGG-19 F1 score: 0.7684210526315789
ResNet-50 F1 score: 0.7053571428571429
********************************
Ensemble accuracy: 0.7272727272727273
VGG-19 accuracy: 0.7727272727272727
ResNet-50 accuracy: 0.7272727272727273
