In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from sklearn.metrics import f1_score, confusion_matrix
from torchmetrics.classification import MulticlassF1Score
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
from PIL import Image
from tqdm import tqdm

In [None]:
# --- Run configuration -------------------------------------------------------
IMG_SIZE = 128    # images are resized to IMG_SIZE x IMG_SIZE before batching
BATCH_SIZE = 4    # training mini-batch size
EPOCHS = 10       # number of full passes over the training set
SEED = 42         # seed for numpy / torch so that runs are repeatable

# Directory that receives the saved figures. It can be redirected by setting
# the EO_PLOT_DIR environment variable; the default is the original location.
PLOT_DIR = os.environ.get(
    "EO_PLOT_DIR",
    "C:/Users/Namya/JobScreenTask/Project_root/venv_313/earth-observation-task/Scripts/plots",
)
os.makedirs(PLOT_DIR, exist_ok=True)

# All tensors and the model are placed on the CPU.
DEVICE = torch.device("cpu")

# Seed the global RNGs (weight initialisation, DataLoader shuffling).
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Load the train / test manifests. Each must provide a 'label' column.
train_df = pd.read_csv("../train.csv")
test_df = pd.read_csv("../test.csv")

# Map labels to integers: class ids follow the sorted order of the labels
# that occur in the training set.
labels = sorted(train_df['label'].unique())
label2idx = {label: idx for idx, label in enumerate(labels)}
idx2label = dict(enumerate(labels))

# Series.map turns labels that are missing from the mapping into NaN, which
# would silently corrupt 'label_idx'. Fail early with a readable message.
unseen_labels = test_df.loc[~test_df['label'].isin(labels), 'label'].unique().tolist()
if unseen_labels:
    raise ValueError(
        f"test.csv contains labels that never occur in train.csv: {unseen_labels}"
    )

train_df['label_idx'] = train_df['label'].map(label2idx)
test_df['label_idx'] = test_df['label'].map(label2idx)

In [None]:
# Preprocessing applied to every image: resize to a fixed square of
# IMG_SIZE x IMG_SIZE, then convert the PIL image to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ]
)

In [None]:
class CachedDataset(Dataset):
    """Dataset that decodes every image once, up front, and keeps it in memory.

    Args:
        df: DataFrame with a 'file' column (image path) and a 'label_idx'
            column (class id). Rows are addressed by position.
        transform: optional callable applied to the cached RGB PIL image on
            every access; when None the PIL image itself is returned.

    Attributes:
        df: the DataFrame passed in.
        transform: the transform passed in.
        cache: list of RGB PIL images, in row order.
        labels: list of 'label_idx' values, in row order.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
        self.cache = []
        # Read the labels once so __getitem__ does not have to build a row
        # Series on every access.
        self.labels = df['label_idx'].tolist()
        for path in tqdm(df['file'], total=len(df), desc="Caching images"):
            # The context manager closes the file as soon as the pixel data
            # has been copied into the converted image.
            with Image.open(path) as raw:
                self.cache.append(raw.convert("RGB"))

    def __len__(self):
        return len(self.cache)

    def __getitem__(self, idx):
        """Return (image, label) for the row at position idx."""
        img = self.cache[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img, self.labels[idx]

In [None]:
class SatelliteDataset(Dataset):
    """Dataset that reads each image from disk at access time.

    Args:
        df: DataFrame with a 'file' column (image path) and a 'label_idx'
            column (class id). Rows are addressed by position.
        transform: optional callable applied to the RGB PIL image; when None
            the PIL image itself is returned.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return (image, label) for the row at position idx."""
        row = self.df.iloc[idx]
        # Close the file handle once the pixels have been decoded.
        with Image.open(row['file']) as raw:
            img = raw.convert("RGB")
        # Use the transform given to this instance, not a module-level global,
        # so the constructor argument is honoured.
        if self.transform is not None:
            img = self.transform(img)
        label = row['label_idx']
        return img, label

In [None]:
# Data loaders. Both datasets decode and cache their images when constructed.
train_loader = DataLoader(
    CachedDataset(train_df, transform=transform),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)
# The test loader is not shuffled, so predictions come back in test_df row order.
test_loader = DataLoader(
    CachedDataset(test_df, transform=transform),
    batch_size=1,
    num_workers=0
)

# ResNet-18 trained from scratch. `weights=None` is the current spelling of the
# deprecated `pretrained=False`.
model = models.resnet18(weights=None, num_classes=len(labels))
# Replace the stem with a 3x3, stride-1 convolution and drop the max-pool so
# the first stage does not downsample the IMG_SIZE x IMG_SIZE inputs.
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()
model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Smoke check: walk the lazily-loading dataset once so that every training
# image is opened and transformed. The returned samples are discarded.
# (`tqdm` is already imported in the import cells at the top of the notebook.)
dataset = SatelliteDataset(train_df, transform=transform)
for i in tqdm(range(len(dataset))):
    dataset[i]

In [None]:
# Training loop: one optimiser step per mini-batch, reporting the mean batch
# loss at the end of every epoch.
print("\n Starting training…\n")
num_batches = len(train_loader)
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    for imgs, targets in train_loader:
        imgs, targets = imgs.to(DEVICE), targets.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch [{epoch+1}/{EPOCHS}] Loss: {total_loss/num_batches:.4f}")
# No summary line after the loop: it only repeated the last epoch's numbers
# and raised NameError when EPOCHS was 0.
print("\n Training complete. Starting evaluation…\n")

In [None]:
# Evaluation: collect the true and predicted class ids for every test sample.
model.eval()
y_true, y_pred = [], []
with torch.no_grad():
    for imgs, targets in test_loader:
        outputs = model(imgs.to(DEVICE))
        preds = torch.argmax(outputs, dim=1)
        # extend + tolist works for any batch size, whereas .item() only
        # accepts single-element tensors.
        y_true.extend(targets.tolist())
        y_pred.extend(preds.tolist())

In [None]:
# Weighted F1 computed with scikit-learn on the plain Python lists.
f1_custom = f1_score(y_true, y_pred, average='weighted')
print(f"Custom F1 score (sklearn): {f1_custom:.4f}")

# The same metric computed with torchmetrics on tensors, as a cross-check.
n_classes = len(labels)
tm_f1 = MulticlassF1Score(num_classes=n_classes, average='weighted').to(DEVICE)
y_true_tensor = torch.tensor(y_true, device=DEVICE)
y_pred_tensor = torch.tensor(y_pred, device=DEVICE)
tm_result = tm_f1(y_pred_tensor, y_true_tensor)  # argument order: preds, target
f1_torchmetrics = tm_result.item()
print(f"Torchmetrics F1 score: {f1_torchmetrics:.4f}")

In [None]:
# Confusion matrix (rows: true class, columns: predicted class).
# Passing the full list of class ids keeps the matrix len(labels) x len(labels)
# even when a class never appears in y_true / y_pred, so the tick labels
# below always line up with the rows and columns.
class_ids = list(range(len(labels)))
cm = confusion_matrix(y_true, y_pred, labels=class_ids)

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=labels, yticklabels=labels, ax=ax)
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
ax.set_title("Confusion Matrix")
fig.tight_layout()
fig.savefig(os.path.join(PLOT_DIR, "confusion_matrix.png"))
plt.show()
plt.close(fig)

print(f"Confusion matrix plot saved: {PLOT_DIR}/confusion_matrix.png")

In [None]:
# Up to 5 correctly and up to 5 incorrectly classified test images.
correct = [i for i, (p, t) in enumerate(zip(y_pred, y_true)) if p == t][:5]
incorrect = [i for i, (p, t) in enumerate(zip(y_pred, y_true)) if p != t][:5]

# Positional view of the test manifest. A separate name with drop=True keeps
# test_df untouched, so this cell can be re-run without accumulating
# 'index' / 'level_0' columns.
# Assumes y_pred[i] / y_true[i] belong to the i-th row of test_df (unshuffled
# test loader) - confirm if the loader configuration changes.
examples_df = test_df.reset_index(drop=True)

fig, axes = plt.subplots(2, 5, figsize=(20, 8))
# Hide every axis first so unused slots stay blank when fewer than 5 examples
# of a kind exist.
for ax in axes.ravel():
    ax.axis('off')
# Top row: correct predictions; bottom row: incorrect predictions.
for row_axes, indices in zip(axes, (correct, incorrect)):
    for ax, idx in zip(row_axes, indices):
        img = plt.imread(examples_df.loc[idx, 'file'])
        pred_lbl = idx2label[y_pred[idx]]
        ax.imshow(img)
        ax.set_title(f"True: {examples_df.loc[idx,'label']}\nPred: {pred_lbl}")
plt.tight_layout()
plt.savefig(os.path.join(PLOT_DIR, "correct_incorrect_examples.png"))
plt.show()
plt.close()

print(f"Correct & incorrect examples plot saved: {PLOT_DIR}/correct_incorrect_examples.png")