In [None]:
!pip install torch torchvision timm

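# Swin Transformer for Image Geolocation

Fine-tunes a pretrained Swin-L model (`swin_large_patch4_window7_224`) to classify images by the state they were taken in, then reports top-1 and top-5 accuracy overall and per state.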

In [None]:
import torch
import torchvision
import timm
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

In [None]:
# Mount Google Drive (force_remount=True forces a fresh mount even if one exists),
# then list the training and test data folders as a quick sanity check that the
# dataset paths used below are visible.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
!ls /content/drive/MyDrive/proj
!ls /content/drive/MyDrive/proj_test/test_data/

# Helpers for saving and loading model training progress

In [None]:
import os
import torch

# Paths for saving/loading
model_save_path = '/content/drive/MyDrive/swin_model.pth'
# NOTE: the optimizer state is stored inside the checkpoint at model_save_path;
# this separate path is currently unused and kept only for compatibility.
optimizer_save_path = '/content/drive/MyDrive/swin_optimizer.pth'


def save_model(model, optimizer, epoch, path=None):
    """Save a training checkpoint containing the epoch, model and optimizer state.

    Args:
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        epoch: The epoch that just finished; ``load_model`` resumes at ``epoch + 1``.
        path: Destination file. Defaults to the module-level ``model_save_path``
            (looked up at call time, so reassigning the global still works).

    The checkpoint is written to a temporary file first and then moved over the
    target with ``os.replace``. On filesystems where rename is atomic, an
    interrupted save (e.g. a runtime timeout) leaves the previous checkpoint
    intact instead of a truncated file.
    """
    if path is None:
        path = model_save_path
    tmp_path = f"{path}.tmp"
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, tmp_path)
    os.replace(tmp_path, path)
    print(f"Model saved after epoch {epoch}")


def load_model(model, optimizer, path=None, map_location='cpu'):
    """Load the model and optimizer states from a checkpoint written by ``save_model``.

    Args:
        model: Module to restore in place via ``load_state_dict``.
        optimizer: Optimizer to restore in place via ``load_state_dict``.
        path: Checkpoint file. Defaults to the module-level ``model_save_path``.
        map_location: Passed to ``torch.load``. The default ``'cpu'`` lets a
            checkpoint saved on a GPU be opened on any machine. ``load_state_dict``
            copies the values into the model's existing parameters, and the
            optimizer moves its state to each parameter's device, so the restored
            state ends up where the model lives.

    Returns:
        The epoch to resume from (saved epoch + 1), or 1 if no checkpoint exists.
    """
    if path is None:
        path = model_save_path
    if not os.path.exists(path):
        print("No saved model found. Starting from scratch.")
        return 1
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Model loaded. Resuming from epoch {start_epoch}")
    return start_epoch

# Data Preparation

In [None]:
# Resize every image to the 224x224 input size of swin_*_224 and normalize with the
# standard ImageNet channel mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Load the training set: each sub-folder of data_dir_train is one class.
# No is_valid_file override here. Passing one disables ImageFolder's default
# image-extension filter (.jpg, .png, ...), which already skips info.txt and also
# skips any other stray non-image file that would otherwise crash the DataLoader.
data_dir_train = "/content/drive/MyDrive/proj"
train_dataset = datasets.ImageFolder(root=data_dir_train, transform=transform)

In [None]:
import os

batch_size = 32
# Cap workers at the number of available CPUs. Spawning more workers than cores
# only adds process overhead and makes DataLoader emit a warning.
num_workers = min(12, os.cpu_count() or 1)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    # Page-locked host memory speeds up host-to-GPU copies; pointless without CUDA.
    pin_memory=torch.cuda.is_available(),
)

# Swin Model Setup

In [None]:
# Pretrained Swin-L with its classifier sized to one output per training class folder.
model = timm.create_model('swin_large_patch4_window7_224', pretrained=True, num_classes=len(train_dataset.classes))
# Fall back to CPU when no GPU is attached, so model.to(device) does not fail
# (training on CPU will be very slow, but the notebook still runs).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Loss Function and Optimizer

In [None]:
# Objective: multi-class cross-entropy computed directly on the model's logits.
loss_fn = nn.CrossEntropyLoss()
# Optimizer: Adam over every parameter (pretrained backbone and new head), lr 1e-4.
adam_optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# Training Loop

In [None]:
epochs = 10

# Resume from the checkpoint written by save_model, so that a runtime disconnect
# only loses the epoch in progress. load_model returns 1 when no checkpoint exists.
# Set resume = False to ignore any existing checkpoint and train from scratch.
# If the checkpoint already covers all `epochs`, the loop below is skipped and the
# restored weights are used as-is.
resume = True
start_epoch = load_model(model, adam_optimizer) if resume else 1

for epoch in range(start_epoch, epochs + 1):
    model.train()
    running_loss = 0.0
    for batch_idx, (inputs, labels) in enumerate(train_loader, start=1):
        # Lightweight progress indicator: print every 10th batch number.
        if batch_idx % 10 == 0:
            print(batch_idx)
        inputs, labels = inputs.to(device), labels.to(device)
        adam_optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        adam_optimizer.step()
        running_loss += loss.item()
    print(f"Epoch [{epoch}/{epochs}], Loss: {running_loss / len(train_loader):.4f}")

    # Checkpoint after every epoch so progress survives a timeout.
    save_model(model, adam_optimizer, epoch)
print("Training Complete")

# Evaluation Loop

In [None]:
from collections import defaultdict
import os

# Load Test Data. As for training, rely on ImageFolder's default image-extension
# filter, which already skips info.txt and any other non-image file.
data_dir_test = "/content/drive/MyDrive/proj_test/test_data/"
test_dataset = datasets.ImageFolder(root=data_dir_test, transform=transform)

# ImageFolder numbers classes by the sorted sub-folder names of its *own* root.
# Test indices therefore only line up with the model's outputs (which follow
# train_dataset.class_to_idx) when both roots contain exactly the same folders.
# Translate test indices to training indices explicitly, so a state missing from
# the test set cannot silently shift every label after it.
unknown_states = sorted(set(test_dataset.classes) - set(train_dataset.class_to_idx))
if unknown_states:
    raise ValueError(f"Test set contains states not seen in training: {unknown_states}")
test_to_train_idx = {
    test_idx: train_dataset.class_to_idx[name]
    for name, test_idx in test_dataset.class_to_idx.items()
}
test_dataset.target_transform = test_to_train_idx.__getitem__

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=min(12, os.cpu_count() or 1),
)

# Class names (state names), indexed the same way as the model's outputs.
class_names = train_dataset.classes
# torch.topk requires k <= number of classes.
top_k = min(5, len(class_names))

model.eval()
correct_top1 = 0
correct_top5 = 0
total = 0
print(len(test_loader))

# Per-state counters, keyed by training class index.
correct_top1_per_label = defaultdict(int)
correct_top5_per_label = defaultdict(int)
total_count_per_label = defaultdict(int)

with torch.no_grad():
    for batch_idx, (inputs, labels) in enumerate(test_loader):
        if batch_idx % 10 == 0:
            print(batch_idx)
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        total += labels.size(0)

        # Boolean hit masks with one entry per image in the batch.
        top1_hits = outputs.argmax(dim=1) == labels
        topk_pred = outputs.topk(top_k, dim=1).indices
        top5_hits = (topk_pred == labels.unsqueeze(1)).any(dim=1)

        correct_top1 += top1_hits.sum().item()
        correct_top5 += top5_hits.sum().item()

        # One .tolist() per tensor keeps the per-state bookkeeping in Python,
        # instead of a device round-trip for every single image.
        for label, hit1, hit5 in zip(labels.tolist(), top1_hits.tolist(), top5_hits.tolist()):
            total_count_per_label[label] += 1
            correct_top1_per_label[label] += int(hit1)
            correct_top5_per_label[label] += int(hit5)

# Print overall accuracies
print(f'Top-1 Accuracy: {100 * correct_top1 / total:.2f}%')
print(f'Top-5 Accuracy: {100 * correct_top5 / total:.2f}%')

print("\nPer-State Top-1 and Top-5 Accuracies:")
for label, state_name in enumerate(class_names):
    total_per_label = total_count_per_label[label]
    if total_per_label == 0:
        # The state exists in training but has no images in the test set.
        print(f"{state_name}: no test images")
        continue
    top1_accuracy = 100 * correct_top1_per_label[label] / total_per_label
    top5_accuracy = 100 * correct_top5_per_label[label] / total_per_label
    print(f"{state_name}: Top-1 Accuracy: {top1_accuracy:.2f}%, Top-5 Accuracy: {top5_accuracy:.2f}%")