In [None]:
import json
import os
import torch
import matplotlib.pyplot as plt 
import numpy as np
import pandas as pd
import torch.nn as nn
from torchsummary import summary
from utils import  format_for_display 
from DataLoader import EuroSAT , UC_MERCED , custom_collate_fn
from engine import train_one_epoch , test_one_epoch
from torchvision import  transforms
from torchvision.transforms import ToTensor


In [None]:
# Reproducibility: fix the RNG seeds consumed by the random augmentations below,
# DataLoader shuffling (the default sampler draws from torch's global RNG) and
# model weight initialisation.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Preprocessing pipelines. Both resize to 256x256 and normalise with the ImageNet
# channel mean/std; only the training pipeline adds random augmentation.
train_transform = transforms.Compose([
    transforms.Resize((256,256)) , 
    transforms.RandomHorizontalFlip(p=0.5),  # Flip half of the images
    transforms.RandomRotation(degrees=15),  # Rotate images between -15 and +15 degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),  # Convert image to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])   
])
test_transform = transforms.Compose([
    transforms.Resize((256,256)) , 
    transforms.ToTensor(),  # Convert image to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])   
])

# Hyperparameters
BATCH_SIZE = 4 
LR = 0.0001
Epochs = 50
device = 'cuda' if torch.cuda.is_available() else "cpu"
device


## Dataset preparation: **EuroSAT**

In [None]:
# Read the three EuroSAT split manifests and order each one's rows by class name.
train_csv, test_csv, val_csv = (
    pd.read_csv(split_path, index_col=0).sort_values(by=['ClassName'], axis=0)
    for split_path in (
        "../Datasets/EuroSAT/EuroSAT/train.csv",
        "../Datasets/EuroSAT/EuroSAT/test.csv",
        "../Datasets/EuroSAT/EuroSAT/validation.csv",
    )
)

In [None]:
# The EuroSAT label map is a JSON object; its keys, in file order, are used as the
# class names (presumably the values are the label ids — verify against the file).
with open("../Datasets/EuroSAT/EuroSAT/label_map.json" , 'r') as label_file:
    labels = json.load(label_file)
class_names = [*labels.keys()]
class_names 

In [None]:
def _rows_as_lists(frame):
    """Return every row of ``frame`` (index excluded) as a plain list of its column values."""
    return [list(row) for _, row in frame.iterrows()]


# Convert each split manifest into the list-of-rows form passed to EuroSAT(data=...).
# Named helper calls replace the old `for i, set in enumerate(...)` dispatch, which
# rebound the builtin `set` for the rest of the notebook.
train_set = _rows_as_lists(train_csv)
val_set = _rows_as_lists(val_csv)
test_set = _rows_as_lists(test_csv)

In [None]:
# Wrap each EuroSAT split in a Dataset. The original cell passed an undefined
# `transform` (NameError on a fresh kernel). Only the training split gets random
# augmentation; validation and test use the deterministic pipeline.
train = EuroSAT(parent_dir = "../Datasets/EuroSAT/EuroSAT/" , data = train_set , transform = train_transform)
val = EuroSAT(parent_dir = "../Datasets/EuroSAT/EuroSAT/" , data = val_set , transform = test_transform)
test = EuroSAT(parent_dir = "../Datasets/EuroSAT/EuroSAT/" , data = test_set , transform = test_transform)

len(train) , len(val) , len(test)

In [None]:
# Shuffle only the training split; validation and test are iterated in a fixed order
# so evaluation (and any batches pulled out for plotting) is reproducible.
train_loader = torch.utils.data.DataLoader(train , shuffle = True , batch_size=BATCH_SIZE)
val_loader = torch.utils.data.DataLoader(val , shuffle = False , batch_size=BATCH_SIZE)
test_loader = torch.utils.data.DataLoader(test, shuffle = False , batch_size=BATCH_SIZE)

len(train_loader) , len(val_loader)  , len(test_loader)

## Dataset preparation: **UC-Merced**

In [None]:
# UC-Merced land-use classes mapped to their integer label ids (21 classes).
class_to_idx = {
    'Agricultural': 0,
    'Airplane' : 1 ,
    'Baseball diamond' : 2,
    'Beach' : 3,
    'Buildings' : 4,
    'Chaparral' : 5,
    'Dense residential' : 6,
    'Forest' : 7 , 
    'Freeway' : 8,
    'Golf course' : 9,
    'Harbor' : 10,
    'Intersection' : 11,
    'Medium residential' : 12,
    'Mobile home park' : 13,
    'Overpass' : 14,
    'Parking lot' : 15,
    'River' : 16,
    'Runway' : 17,
    'Sparse residential' : 18,
    'Storage tanks' : 19,
    'Tennis court' : 20
    }
# Index -> name lookup derived from class_to_idx, so the two cannot drift apart
# (previously this list was maintained by hand alongside the dict).
class_names = sorted(class_to_idx, key=class_to_idx.get)
# class_names[i] is only a valid lookup if the ids are exactly 0..N-1.
assert [class_to_idx[name] for name in class_names] == list(range(len(class_names))), \
    "class_to_idx ids must be contiguous and start at 0"

In [None]:
# UC-Merced is stored as separate train/val/test folders. Only the training split is
# augmented; validation previously used train_transform, which injected random
# flips/rotations/colour jitter into the validation metrics.
train = UC_MERCED(parent_dir = "../Datasets/UC-MERCED/UCMerced_LandUse/Images/train" , transform = train_transform)
val = UC_MERCED(parent_dir = "../Datasets/UC-MERCED/UCMerced_LandUse/Images/val" , transform = test_transform)
test = UC_MERCED(parent_dir = "../Datasets/UC-MERCED/UCMerced_LandUse/Images/test" , transform = test_transform)
print(len(train) , len(test) , len(val))

# Shuffle only the training split so evaluation order is reproducible.
train_loader = torch.utils.data.DataLoader(train , shuffle = True , batch_size=BATCH_SIZE , collate_fn=custom_collate_fn )
val_loader = torch.utils.data.DataLoader(val , shuffle = False , batch_size=BATCH_SIZE , collate_fn=custom_collate_fn)
# Test batch size 8 presumably chosen to fill the 2x4 prediction grid plotted later.
test_loader = torch.utils.data.DataLoader(test, shuffle = False , batch_size=8 , collate_fn=custom_collate_fn)
print(len(train_loader) , len(test_loader) , len(val_loader))

## Visualize some training samples

In [None]:
# Advance the shuffled training loader seven times and keep the last batch drawn
# (the 7th) as the sample to inspect and plot.
skip_count = 7
train_iter = iter(train_loader)
for _ in range(skip_count):
    images, labels = next(train_iter)

images.shape , labels


In [None]:
# utils.format_for_display presumably un-normalises the batch / reorders channels
# so it can be passed to imshow — compare shapes before and after to confirm.
formatted_images = format_for_display(images)
(images[0].shape, formatted_images[0].shape)

In [None]:
# Show the sampled training batch on a 2x4 grid, titled with each image's class name.
nrows = 2
ncolumns = 4
fig, axs = plt.subplots(nrows, ncolumns, figsize=(15, 6))

# Flatten the 2-D axes array so subplots can be paired with images one-to-one.
axs = axs.flatten()

# Hide every axis first so grid cells left unused (batch smaller than the grid,
# e.g. BATCH_SIZE=4) render blank instead of as empty framed plots.
for ax in axs:
    ax.axis('off')

# zip stops at the shorter sequence, so a batch larger than the grid no longer
# raises IndexError; int() accepts both Python ints and 0-d label tensors.
for ax, image, label in zip(axs, formatted_images, labels):
    ax.imshow(image)  # Display the image
    ax.set_title(class_names[int(label)])  # Title with the image's class name

plt.show()

## Neural network training with PyTorch

In [None]:
# Current working directory: the relative ../Datasets and weights/ paths used in
# this notebook resolve against it. Explicit call instead of the `pwd` automagic,
# which only works when IPython automagic is enabled and no variable named `pwd` exists.
os.getcwd()

In [None]:
from Model import NN_1 , NN_2 , NN_3 , NN_4

# Model, optimiser and loss.
model = NN_4().to(device)
optimizer = torch.optim.Adam(model.parameters() , lr = LR)
cross_entropy = torch.nn.CrossEntropyLoss()

# Learning-rate schedule; one of "onecyclelr", "multi_step_lr", "step_lr", "exponential".
# An unrecognised name now raises instead of silently falling back to ExponentialLR.
scheduler = "exponential"
if scheduler == "onecyclelr":
    # OneCycleLR is defined over steps_per_epoch * epochs optimiser steps, i.e. it
    # expects scheduler.step() after every batch rather than every epoch.
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LR, steps_per_epoch=len(train_loader), epochs=Epochs, pct_start=0.2)
elif scheduler == "multi_step_lr":
    lr_drop_list = [4, 8]
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_drop_list)
elif scheduler == "step_lr":
    step_size = 10
    gamma = 0.5
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = step_size , gamma = gamma)
elif scheduler == "exponential":
    gamma = 0.98
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer , gamma)
else:
    raise ValueError(
        f"Unknown scheduler {scheduler!r}; expected one of "
        "'onecyclelr', 'multi_step_lr', 'step_lr', 'exponential'"
    )
summary(model , input_size=( 3 , 256 , 256))

In [None]:
# Train the model. Despite its name, engine.train_one_epoch is handed the total epoch
# count (`Epochs`) and both loaders, so it appears to run the whole training loop.
# `out_dir` is a file path — presumably where the best checkpoint is saved (the same
# path is loaded back for evaluation). TODO confirm semantics of `lora` / `resume`.
out_dir = 'weights/best_checkpoint.pth'
train_loss, val_loss, current_lr = train_one_epoch(
    model,
    training_loader=train_loader,
    validation_loader=val_loader,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    epochs=Epochs,
    loss_func=cross_entropy,
    lora=False,
    device=device,
    out_dir=out_dir,
    resume=False,
)

In [None]:
# Restore the best checkpoint written during training and evaluate it on the test split.
# map_location remaps stored tensors onto the current device, so a checkpoint saved
# on GPU can be reopened on a CPU-only machine (torch.load otherwise fails there).
checkpoint = torch.load("weights/best_checkpoint.pth", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

model.eval()
model.to(device)
test_loss = test_one_epoch(model = model , test_loader= test_loader , loss_func=cross_entropy , device = device)

In [None]:
# Advance the test loader to its 8th batch and predict classes for that batch.
test_iter = iter(test_loader)
batch_no = 8
for batch in range(batch_no):
    images , labels = next(test_iter)
# Only the batch we keep needs to be on the device; the skipped ones no longer get
# transferred on every loop iteration.
images = images.to(device)

# Make sure dropout/batch-norm are in inference mode even if this cell is re-run
# after training without the evaluation cell.
model.eval()
with torch.no_grad():
    logits = model(images).to('cpu')
logits = logits.numpy()
pred_classes = np.argmax(logits, axis=1)  # highest-scoring class per image
pred_classes , labels

In [None]:
# Plot predicted vs. ground-truth class for each image of the selected test batch.
images = images.to('cpu')
formatted_images  = format_for_display(images)
nrows = 2
ncolumns = 4
fig, axs = plt.subplots(nrows, ncolumns, figsize=(15, 6))

axs = axs.flatten()
# Hide every axis first so unused grid cells (e.g. a short final batch) render blank.
for ax in axs:
    ax.axis('off')

# zip stops at the shorter sequence, so a batch larger than the grid no longer
# raises IndexError.
for ax, image, pred, label in zip(axs, formatted_images, pred_classes, labels):
    ax.imshow(image)
    title = f'Predicted: {class_names[pred]}\nGT: {class_names[label.item()]}'
    ax.set_title(title)

plt.show()
