In [1]:
!pip install gdown

Defaulting to user installation because normal site-packages is not writeable


In [2]:
import os
import random
import sys

# Make the project-local `src` package (one directory up) importable from this notebook.
sys.path.append('..')

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import matplotlib.pyplot as plt
import torchvision
from torchvision import datasets, models, transforms
from torch.optim.lr_scheduler import StepLR

from src.models import ModelResnet152dTimm
from src.dataloader import EuroSatDownloader, MiniImageNetDataSet, createEuroSatDataLoaders
from src.training import train_fine_tuning, eval_method, eval_func
from src.modelvis import visualize_models

# Seed the Python and torch RNGs so that anything drawing from them (e.g. episode
# class sampling, new-layer initialisation) can be repeated.
# NOTE(review): if src.dataloader samples with numpy, seed numpy here as well.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Let cuDNN benchmark candidate convolution algorithms on the first batch and cache
# the fastest one. This pays off when input sizes stay constant across batches. The
# chosen algorithms may differ between runs, so results are not bit-identical.
cudnn.benchmark = True

# Interactive matplotlib mode: figures are drawn immediately and can be updated in place.
# The trailing `;` suppresses the returned context-manager repr in the cell output.
plt.ion();

<contextlib.ExitStack at 0x1e5176dc5b0>

In [3]:
# Record the installed torchvision and torch versions (in that order) next to the results.
for _lib in (torchvision, torch):
    print(_lib.__version__)

0.9.2
1.8.2


In [4]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(device)

cuda:0


### Download the EuroSAT dataset

In [5]:
# Directory the EuroSAT archive is extracted into. os.path.join keeps the path
# portable; a literal '..\\data' is a single oddly named entry on POSIX systems.
extractDir = os.path.join('..', 'data')
url = "https://zenodo.org/records/7711810/files/EuroSAT_RGB.zip?download=1"

# download() skips the fetch when the data is already present (see the log output).
# Its return value is used as the dataset directory.
downloader = EuroSatDownloader(url, extractDir)
dataDir = downloader.download()

2024-01-07 05:16:59,776 - INFO - EuroSat database already exists.


# Hyperparameters

In [6]:
# --- Optimiser: SGD with a StepLR schedule ---
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-3
# NOTE(review): 0.09 is unusually low for SGD momentum (0.9 is the common choice); confirm intended.
MOMENTUM = 0.09
STEP_SIZE = 5   # scheduler steps between learning-rate decays
GAMMA = 0.5     # multiplicative learning-rate decay factor

# --- Training schedule ---
NUM_EPOCHS = 15
BATCH_SIZE = 25
EPISODES = 5    # independent fine-tuning episodes whose test accuracy is averaged

# Head size of the pretrained checkpoint (the 64 used when rebuilding ModelResnet152dTimm).
num_of_classes = 64

In [7]:
# Standard ImageNet per-channel mean / std, shared by both pipelines.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

data_transforms = {
    # Training: resize, then light photometric / flip / blur augmentation.
    'train': transforms.Compose([
        transforms.Resize(224),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
        transforms.RandomHorizontalFlip(),
        transforms.GaussianBlur(3, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
    # Evaluation: deterministic resize + normalisation only.
    'test': transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])
}

# Build paths with os.path.join so the notebook also runs outside Windows.
# NOTE(review): this overrides the `dataDir` returned by the downloader cell; confirm they agree.
dataDir = os.path.join('..', 'data', 'EuroSAT_RGB')
modelPath = os.path.join('..', 'data', 'models', 'best_model_ModelResnet152dTimm_FixedTransforms.pth')

In [8]:
print('Please wait patiently, it may take some seconds...')

# Episode-level accuracy thresholds used to tag the sampled classes as "bad" / "good".
# Assumes eval_method returns accuracy as a fraction in [0, 1] (matches the logged values).
BAD_ACC_THRESHOLD = 0.6
GOOD_ACC_THRESHOLD = 0.7

# Loop-invariant work, done once. load_state_dict copies the tensors into each
# freshly built model, so the same state dict can be reused for every episode.
pretrained_state = torch.load(modelPath, map_location=device)
criterion = torch.nn.CrossEntropyLoss()

total_test_acc = 0.0

good_classes = [set() for _ in range(EPISODES)]
bad_classes = set()

for episode in range(EPISODES):
    print(f"\nTraining Episode: {episode + 1}...")

    # Sample this episode's classes first so the new head can be sized from them
    # instead of a hard-coded class count.
    dataloaders, class_names, dataset_sizes = createEuroSatDataLoaders(data_transforms, dataDir, split=0.25, batch_size = BATCH_SIZE)
    print(dataset_sizes)

    # Rebuild the backbone with the checkpoint's head size, restore the pretrained
    # weights, then swap in a fresh head with one output per sampled class.
    model = ModelResnet152dTimm(num_of_classes)
    model.load_state_dict(pretrained_state)
    model.model.fc = nn.Linear(model.model.fc.in_features, len(class_names))
    # Move the whole model (not just the new head) before building the optimizer,
    # as the torch.optim documentation recommends.
    model = model.to(device)

    # Full fine-tuning: every layer is trainable.
    for param in model.parameters():
        param.requires_grad = True

    optimizer = torch.optim.SGD(params = model.parameters(), lr = LEARNING_RATE, momentum = MOMENTUM, weight_decay = WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = STEP_SIZE, gamma = GAMMA)

    model, _, _ = train_fine_tuning(model, dataloaders, criterion, optimizer, scheduler, num_epochs = NUM_EPOCHS, learning_rate = LEARNING_RATE)
    test_loader = dataloaders['test']
    acc_test, test_loss = eval_method(net=model, data_loader=test_loader)
    total_test_acc += acc_test

    # Tag all of this episode's classes by the episode's overall accuracy.
    if acc_test < BAD_ACC_THRESHOLD:
        bad_classes.update(class_names)
    elif acc_test > GOOD_ACC_THRESHOLD:
        good_classes[episode].update(class_names)
    print(class_names)
    print(f"Episode {episode + 1}: Test Accuracy: {acc_test:.4f}")

average_test_acc = total_test_acc / EPISODES
print("\nAverage Test Accuracy:", average_test_acc)

Please wait patiently, it may take some seconds...

Training Episode: 1...


2024-01-07 05:17:03,389 - INFO - Loading pretrained weights from url (https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet152d_ra2-5cac0439.pth)


{'train': 25, 'test': 75}
epoch: 0, accuracy: 0.160000, avg. loss: 1.633268, test accuracy: 0.333333
epoch: 1, accuracy: 0.280000, avg. loss: 1.585307, test accuracy: 0.320000
epoch: 2, accuracy: 0.160000, avg. loss: 1.604901, test accuracy: 0.346667
epoch: 3, accuracy: 0.120000, avg. loss: 1.594753, test accuracy: 0.360000
epoch: 4, accuracy: 0.160000, avg. loss: 1.598347, test accuracy: 0.346667
epoch: 5, accuracy: 0.280000, avg. loss: 1.556106, test accuracy: 0.293333
epoch: 6, accuracy: 0.520000, avg. loss: 1.538106, test accuracy: 0.293333
epoch: 7, accuracy: 0.240000, avg. loss: 1.539082, test accuracy: 0.346667
epoch: 8, accuracy: 0.320000, avg. loss: 1.513519, test accuracy: 0.333333
epoch: 9, accuracy: 0.400000, avg. loss: 1.482218, test accuracy: 0.333333
epoch: 10, accuracy: 0.520000, avg. loss: 1.478004, test accuracy: 0.320000
epoch: 11, accuracy: 0.320000, avg. loss: 1.508707, test accuracy: 0.346667
epoch: 12, accuracy: 0.280000, avg. loss: 1.512326, test accuracy: 0.346

2024-01-07 05:20:13,331 - INFO - Loading pretrained weights from url (https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet152d_ra2-5cac0439.pth)


{'train': 25, 'test': 75}
epoch: 0, accuracy: 0.160000, avg. loss: 1.642227, test accuracy: 0.186667
epoch: 1, accuracy: 0.080000, avg. loss: 1.674231, test accuracy: 0.160000
epoch: 2, accuracy: 0.240000, avg. loss: 1.620021, test accuracy: 0.160000
epoch: 3, accuracy: 0.120000, avg. loss: 1.654762, test accuracy: 0.173333
epoch: 4, accuracy: 0.160000, avg. loss: 1.639924, test accuracy: 0.186667
epoch: 5, accuracy: 0.160000, avg. loss: 1.602441, test accuracy: 0.173333
epoch: 6, accuracy: 0.160000, avg. loss: 1.595123, test accuracy: 0.213333
epoch: 7, accuracy: 0.080000, avg. loss: 1.628648, test accuracy: 0.240000
epoch: 8, accuracy: 0.080000, avg. loss: 1.599404, test accuracy: 0.226667
epoch: 9, accuracy: 0.360000, avg. loss: 1.540437, test accuracy: 0.240000
epoch: 10, accuracy: 0.240000, avg. loss: 1.553781, test accuracy: 0.200000
epoch: 11, accuracy: 0.240000, avg. loss: 1.546798, test accuracy: 0.186667
epoch: 12, accuracy: 0.200000, avg. loss: 1.561395, test accuracy: 0.186

2024-01-07 05:23:36,000 - INFO - Loading pretrained weights from url (https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet152d_ra2-5cac0439.pth)


{'train': 25, 'test': 75}
epoch: 0, accuracy: 0.280000, avg. loss: 1.565910, test accuracy: 0.106667
epoch: 1, accuracy: 0.280000, avg. loss: 1.502577, test accuracy: 0.120000
epoch: 2, accuracy: 0.320000, avg. loss: 1.511392, test accuracy: 0.173333
epoch: 3, accuracy: 0.320000, avg. loss: 1.503307, test accuracy: 0.186667
epoch: 4, accuracy: 0.280000, avg. loss: 1.475580, test accuracy: 0.200000
epoch: 5, accuracy: 0.360000, avg. loss: 1.506248, test accuracy: 0.200000
epoch: 6, accuracy: 0.400000, avg. loss: 1.512616, test accuracy: 0.240000
epoch: 7, accuracy: 0.280000, avg. loss: 1.485583, test accuracy: 0.226667
epoch: 8, accuracy: 0.400000, avg. loss: 1.451588, test accuracy: 0.226667
epoch: 9, accuracy: 0.360000, avg. loss: 1.451431, test accuracy: 0.226667
epoch: 10, accuracy: 0.320000, avg. loss: 1.454059, test accuracy: 0.240000
epoch: 11, accuracy: 0.280000, avg. loss: 1.491456, test accuracy: 0.226667
epoch: 12, accuracy: 0.440000, avg. loss: 1.470896, test accuracy: 0.240

2024-01-07 05:26:51,016 - INFO - Loading pretrained weights from url (https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet152d_ra2-5cac0439.pth)


{'train': 25, 'test': 75}
epoch: 0, accuracy: 0.120000, avg. loss: 1.753704, test accuracy: 0.093333
epoch: 1, accuracy: 0.120000, avg. loss: 1.728337, test accuracy: 0.133333
epoch: 2, accuracy: 0.040000, avg. loss: 1.722818, test accuracy: 0.160000
epoch: 3, accuracy: 0.160000, avg. loss: 1.689935, test accuracy: 0.173333
epoch: 4, accuracy: 0.120000, avg. loss: 1.645311, test accuracy: 0.160000
epoch: 5, accuracy: 0.120000, avg. loss: 1.719704, test accuracy: 0.160000
epoch: 6, accuracy: 0.040000, avg. loss: 1.694996, test accuracy: 0.146667
epoch: 7, accuracy: 0.080000, avg. loss: 1.699202, test accuracy: 0.120000
epoch: 8, accuracy: 0.160000, avg. loss: 1.707182, test accuracy: 0.120000
epoch: 9, accuracy: 0.160000, avg. loss: 1.648751, test accuracy: 0.106667
epoch: 10, accuracy: 0.160000, avg. loss: 1.664260, test accuracy: 0.093333
epoch: 11, accuracy: 0.120000, avg. loss: 1.632955, test accuracy: 0.120000
epoch: 12, accuracy: 0.080000, avg. loss: 1.670627, test accuracy: 0.106

2024-01-07 05:29:58,593 - INFO - Loading pretrained weights from url (https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet152d_ra2-5cac0439.pth)


{'train': 25, 'test': 75}
epoch: 0, accuracy: 0.280000, avg. loss: 1.647789, test accuracy: 0.106667
epoch: 1, accuracy: 0.280000, avg. loss: 1.616148, test accuracy: 0.093333
epoch: 2, accuracy: 0.320000, avg. loss: 1.628009, test accuracy: 0.093333
epoch: 3, accuracy: 0.200000, avg. loss: 1.624410, test accuracy: 0.120000
epoch: 4, accuracy: 0.280000, avg. loss: 1.580313, test accuracy: 0.186667
epoch: 5, accuracy: 0.280000, avg. loss: 1.594694, test accuracy: 0.186667
epoch: 6, accuracy: 0.280000, avg. loss: 1.559619, test accuracy: 0.173333
epoch: 7, accuracy: 0.360000, avg. loss: 1.556850, test accuracy: 0.173333
epoch: 8, accuracy: 0.400000, avg. loss: 1.528625, test accuracy: 0.200000
epoch: 9, accuracy: 0.440000, avg. loss: 1.524285, test accuracy: 0.213333
epoch: 10, accuracy: 0.480000, avg. loss: 1.499158, test accuracy: 0.226667
epoch: 11, accuracy: 0.240000, avg. loss: 1.565631, test accuracy: 0.226667
epoch: 12, accuracy: 0.320000, avg. loss: 1.521159, test accuracy: 0.253

In [9]:
# For each episode, show the classes tagged "bad" in any episode that were not tagged
# "good" in that particular episode; then show the full set of "bad" classes.
for episode_good_set in good_classes:
    print(bad_classes.difference(episode_good_set))
print(bad_classes)

{'Forest', 'AnnualCrop', 'PermanentCrop', 'Pasture', 'Residential', 'SeaLake', 'River', 'Highway', 'HerbaceousVegetation'}
{'Forest', 'AnnualCrop', 'PermanentCrop', 'Pasture', 'Residential', 'SeaLake', 'River', 'Highway', 'HerbaceousVegetation'}
{'Forest', 'AnnualCrop', 'PermanentCrop', 'Pasture', 'Residential', 'SeaLake', 'River', 'Highway', 'HerbaceousVegetation'}
{'Forest', 'AnnualCrop', 'PermanentCrop', 'Pasture', 'Residential', 'SeaLake', 'River', 'Highway', 'HerbaceousVegetation'}
{'Forest', 'AnnualCrop', 'PermanentCrop', 'Pasture', 'Residential', 'SeaLake', 'River', 'Highway', 'HerbaceousVegetation'}
{'Forest', 'AnnualCrop', 'PermanentCrop', 'Pasture', 'Residential', 'SeaLake', 'River', 'Highway', 'HerbaceousVegetation'}
