# In [None]:
import os
from collections import Counter

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, f1_score
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

# ------------------------------
# Setup Device
# ------------------------------
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ------------------------------
# Data Transformations (Augmentation + Normalization)
# ------------------------------
# Per-channel statistics shared by both pipelines (the standard ImageNet
# values used with torchvision's pretrained weights).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, light random augmentation (horizontal flip and
# rotation of up to 10 degrees), then tensor conversion and normalization.
train_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Validation pipeline: identical preprocessing but no random augmentation,
# so evaluation is deterministic.
val_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# ------------------------------
# Data Loading
# ------------------------------
# Both splits live under the same dataset root, one sub-folder per split.
train_dir, val_dir = (
    os.path.join('state_dataset', split) for split in ('train', 'valid')
)

# ImageFolder derives each sample's label from the sub-directory it is in.
train_dataset = datasets.ImageFolder(train_dir, transform=train_transforms)
val_dataset = datasets.ImageFolder(val_dir, transform=val_transforms)

batch_size = 32

# Only the training data is shuffled; validation keeps a fixed order.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

# ------------------------------
# Model Setup (Pretrained ResNet34 + Dropout + BatchNorm)
# ------------------------------
from torchvision.models import resnet34, ResNet34_Weights

# Start from a ResNet-34 backbone initialised with torchvision's default
# pretrained weights.
model = resnet34(weights=ResNet34_Weights.DEFAULT)

# Size the classifier from the dataset instead of hard-coding the class
# count: the loss below is weighted with one weight per class found in
# train_dataset, so the output width has to match that number exactly.
num_classes = len(train_dataset.classes)

# Replace the stock fully connected layer with a small regularised head:
# dropout -> 512-unit hidden layer with batch norm and ReLU -> dropout ->
# one logit per class.
num_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(num_features, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, num_classes)
)

model = model.to(device)

# Fine-tune the whole network: make sure every parameter (backbone and new
# head alike) is trainable.
for param in model.parameters():
    param.requires_grad = True

# ------------------------------
# Loss Function, Optimizer, Scheduler
# ------------------------------
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Per-class sample counts of the training split.
class_counts = Counter(train_dataset.targets)

# 'balanced' weighting gives rarer classes a larger weight so that they
# contribute more to the loss.
_train_targets = np.array(train_dataset.targets)
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(_train_targets),
    y=_train_targets,
)
weights_tensor = torch.tensor(class_weights, dtype=torch.float, device=device)

# Class-weighted cross-entropy with label smoothing.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1, weight=weights_tensor)

# Adam over every model parameter, with L2-style weight decay.
optimizer = optim.Adam(model.parameters(), weight_decay=1e-4, lr=0.0005)

# Multiply the learning rate by 0.3 once the monitored value (the validation
# loss passed to scheduler.step) has not decreased for more than 3 epochs.
# NOTE(review): `verbose` is flagged as deprecated in recent PyTorch
# releases - confirm against the installed version.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.3,
    patience=3,
    verbose=True,
)

# ------------------------------
# Early Stopping Parameters
# ------------------------------
# Training stops once the validation loss has failed to improve for
# `patience` consecutive epochs; the best loss seen so far starts at +inf so
# that the first epoch always counts as an improvement.
patience = 5
best_val_loss = float("inf")
epochs_without_improvement = 0

# ------------------------------
# Metric History for Plotting
# ------------------------------
# One value is appended to each list per completed epoch.
epochs_history = []
train_loss_history, val_loss_history = [], []
train_accuracy_history, val_accuracy_history = [], []
train_f1_history, val_f1_history = [], []

# ------------------------------
# Training Loop
# ------------------------------
num_epochs = 40
for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    running_loss = 0.0
    all_train_preds = []
    all_train_labels = []

    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Scale the batch loss by the batch size so that dividing by the
        # dataset size afterwards yields a per-sample epoch average.
        running_loss += inputs.size(0) * loss.item()
        # The predicted class is the index of the largest logit.
        preds = torch.argmax(outputs, dim=1)
        all_train_preds.extend(preds.cpu().numpy())
        all_train_labels.extend(labels.cpu().numpy())

    train_loss = running_loss / len(train_dataset)
    train_accuracy = accuracy_score(all_train_labels, all_train_preds)
    train_f1 = f1_score(all_train_labels, all_train_preds, average='weighted')

    # ---- Validation pass (eval mode, no gradient tracking) ----
    model.eval()
    val_running_loss = 0.0
    all_val_preds = []
    all_val_labels = []

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            val_running_loss += inputs.size(0) * loss.item()
            preds = torch.argmax(outputs, dim=1)
            all_val_preds.extend(preds.cpu().numpy())
            all_val_labels.extend(labels.cpu().numpy())

    val_loss = val_running_loss / len(val_dataset)
    val_accuracy = accuracy_score(all_val_labels, all_val_preds)
    val_f1 = f1_score(all_val_labels, all_val_preds, average='weighted')

    # ---- Record this epoch's metrics for the plots ----
    for history, value in (
        (epochs_history, epoch + 1),
        (train_loss_history, train_loss),
        (val_loss_history, val_loss),
        (train_accuracy_history, train_accuracy),
        (val_accuracy_history, val_accuracy),
        (train_f1_history, train_f1),
        (val_f1_history, val_f1),
    ):
        history.append(value)

    print(f"Epoch {epoch+1}/{num_epochs} - "
          f"Train Loss: {train_loss:.4f} | Accuracy: {train_accuracy:.4f} | F1-score: {train_f1:.4f}")
    print(f"Epoch {epoch+1}/{num_epochs} - "
          f"Val Loss: {val_loss:.4f} | Accuracy: {val_accuracy:.4f} | F1-score: {val_f1:.4f}\n")

    # Let the scheduler react to the validation loss of this epoch.
    scheduler.step(val_loss)

    # ---- Checkpointing and early stopping ----
    if val_loss < best_val_loss:
        # New best validation loss: remember it, reset the stall counter and
        # overwrite the checkpoint on disk.
        best_val_loss = val_loss
        epochs_without_improvement = 0
        torch.save(model.state_dict(), "best_model.pth")
        continue

    epochs_without_improvement += 1
    if epochs_without_improvement >= patience:
        print(f"Early stopping triggered after {epoch+1} epochs.")
        break

# ------------------------------
# Plotting Metrics
# ------------------------------
# One stacked panel per metric, each comparing the training and validation
# curves over the recorded epochs.
fig, axes = plt.subplots(3, 1, figsize=(10, 15))
plt.subplots_adjust(hspace=0.4)

_panels = (
    ("Loss", train_loss_history, val_loss_history),
    ("Accuracy", train_accuracy_history, val_accuracy_history),
    ("F1 Score", train_f1_history, val_f1_history),
)

for ax, (metric, train_values, val_values) in zip(axes, _panels):
    ax.plot(epochs_history, train_values, marker='o', label=f"Train {metric}")
    ax.plot(epochs_history, val_values, marker='o', label=f"Validation {metric}")
    ax.set_title(f"{metric} per Epoch")
    ax.set_xlabel("Epoch")
    ax.set_ylabel(metric)
    ax.legend()
    ax.grid(True)

plt.show()

# ------------------------------
# Final Evaluation of Best Model
# ------------------------------
# Restore the checkpoint written whenever the validation loss improved.
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU still loads in a CPU-only session.
model.load_state_dict(torch.load("best_model.pth", map_location=device))
model.eval()

final_preds = []
final_labels = []

# Re-run the validation set through the restored weights without gradient
# tracking and collect predictions alongside the ground-truth labels.
with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        # The predicted class is the index of the largest logit.
        _, preds = torch.max(outputs, 1)
        final_preds.extend(preds.cpu().numpy())
        final_labels.extend(labels.cpu().numpy())

final_accuracy = accuracy_score(final_labels, final_preds)
final_f1 = f1_score(final_labels, final_preds, average='weighted')

print("\n Final Evaluation of Best Model:")
print(f"Final Validation Accuracy: {final_accuracy:.4f}")
print(f"Final Validation F1 Score: {final_f1:.4f}")
