# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as T
from torch.utils.data import DataLoader
import os
from sklearn.cluster import KMeans
import numpy as np
from tqdm import tqdm
import pandas as pd

class AngleDataset(Dataset):
    """Image dataset yielding ``(image, label, filename)`` triples.

    Args:
        image_dir: Directory that contains the image files.
        labels_df: Optional DataFrame with ``filename`` and ``Region_ID``
            columns. When given, samples follow its row order and the label
            is ``int(Region_ID)``. When ``None``, every ``.jpg``/``.jpeg``/
            ``.png`` file (case-insensitive) in ``image_dir`` is used, in
            sorted order, with label ``-1`` -- the value the clustering and
            evaluation helpers treat as "unlabelled".
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    _IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, image_dir, labels_df=None, transform=None):
        self.image_dir = image_dir
        self.labels_df = labels_df
        self.transform = transform
        if labels_df is not None:
            self.filenames = labels_df['filename'].tolist()
        else:
            # Sorted so sample order is reproducible; os.listdir order is arbitrary.
            self.filenames = sorted(
                f for f in os.listdir(image_dir)
                if f.lower().endswith(self._IMAGE_EXTENSIONS)
            )

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        filename = self.filenames[idx]
        if self.labels_df is None:
            label = -1  # unlabelled sample
        else:
            label = int(self.labels_df.iloc[idx]['Region_ID'])

        img_path = os.path.join(self.image_dir, filename)
        # convert() returns a new in-memory image, so the file can be closed right away.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label, filename

class Encoder(nn.Module):
    """ResNet-50 trunk followed by a linear projection to ``latent_dim``.

    Args:
        latent_dim: Size of the embedding produced by ``forward``.
        pretrained: Load torchvision's ``'DEFAULT'`` ResNet-50 weights when
            True; build a randomly initialised network otherwise.
    """

    def __init__(self, latent_dim=128, pretrained=True):
        super().__init__()

        backbone = models.resnet50(weights='DEFAULT' if pretrained else None)
        # Keep every child except the last one (torchvision's ``fc`` head),
        # which is replaced by the projection below.
        trunk_layers = list(backbone.children())[:-1]
        self.features = nn.Sequential(*trunk_layers)
        self.fc = nn.Linear(backbone.fc.in_features, latent_dim)

    def forward(self, x):
        pooled = self.features(x)
        return self.fc(pooled.flatten(start_dim=1))

class Decoder(nn.Module):
    """Decode a latent vector into a 256x256 image.

    ``self.fc`` projects the latent vector to a 512x4x4 feature map. Six
    stride-2 ``ConvTranspose2d`` layers then upsample it
    (4 -> 8 -> 16 -> 32 -> 64 -> 128 -> 256). Every transposed convolution
    except the last is followed by ``BatchNorm2d`` and ``ReLU``.

    Args:
        latent_dim: Size of the input latent vector.
        output_channels: Number of channels of the reconstructed image.
        use_sigmoid: If True, squash the output into [0, 1]. Defaults to
            False because the reconstruction targets in this file go through
            ``T.Normalize`` with ImageNet statistics. Those targets contain
            values outside [0, 1], which a sigmoid output can never match.
            Enable it only when targets are in [0, 1].

    The ``nn.Sequential`` layout (and therefore the ``state_dict`` keys) is
    the same as before. The optional ``Sigmoid`` is parameter-free, so
    checkpoints load with either setting.
    """

    def __init__(self, latent_dim=128, output_channels=3, use_sigmoid=False):
        super().__init__()

        # Expand to the initial 512x4x4 feature map.
        self.fc = nn.Linear(latent_dim, 512 * 4 * 4)

        widths = [512, 256, 128, 64, 32, 16]
        layers = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            # Each block doubles spatial size: 4x4 -> 8x8 -> ... -> 128x128.
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]
        # Final upsampling to 256x256 with the requested number of channels.
        layers.append(
            nn.ConvTranspose2d(widths[-1], output_channels, kernel_size=4, stride=2, padding=1)
        )
        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        x = self.fc(x)
        x = x.view(-1, 512, 4, 4)
        x = self.decoder(x)
        return x

class Autoencoder(nn.Module):
    """Pair of ``Encoder`` and ``Decoder``; the encoder output is the embedding used for clustering.

    Args:
        latent_dim: Size of the latent vector shared by encoder and decoder.
        pretrained: Forwarded to ``Encoder``.
    """

    def __init__(self, latent_dim=128, pretrained=True):
        super().__init__()
        self.encoder = Encoder(latent_dim, pretrained)
        self.decoder = Decoder(latent_dim)
        self.latent_dim = latent_dim

    def forward(self, x):
        """Return ``(reconstruction, latent)`` for the input batch ``x``."""
        z = self.encoder(x)
        return self.decoder(z), z

    def encode(self, x):
        """Return only the latent embedding of ``x``."""
        return self.encoder(x)

def _reconstruction_pass(model, loader, criterion, device, desc, optimizer=None):
    """Run one pass of reconstruction loss over ``loader``; return the mean per-sample loss.

    Weights are updated only when ``optimizer`` is given. Otherwise the model
    is put in eval mode and gradients are not tracked. The mean is taken over
    the samples actually seen. Returns ``nan`` if the loader yields nothing.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    n_seen = 0
    with torch.set_grad_enabled(training):
        pbar = tqdm(loader, desc=desc)
        for images, _, _ in pbar:
            images = images.to(device)

            reconstructed, _ = model(images)
            loss = criterion(reconstructed, images)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            n_seen += batch_size
            pbar.set_postfix(loss=loss.item())
    return total_loss / n_seen if n_seen else float('nan')


def train_autoencoder(model, train_loader, val_loader, num_epochs, device, learning_rate=1e-4,
                      checkpoint_path='best_autoencoder.pth', restore_best=True):
    """Train ``model`` to reconstruct its input images with an MSE loss (Adam optimiser).

    Loaders must yield ``(images, _, _)`` batches and ``model(images)`` must
    return ``(reconstruction, latent)``.

    Args:
        model: Autoencoder to train; moved to ``device`` in place.
        train_loader: Batches used for optimisation.
        val_loader: Batches used to pick the best epoch.
        num_epochs: Number of training epochs.
        device: Device the model and batches are moved to.
        learning_rate: Adam learning rate.
        checkpoint_path: File the best-validation ``state_dict`` is written to
            every time the validation loss improves.
        restore_best: If True (default), load the best-validation weights back
            into ``model`` before returning. Set False to keep the final-epoch
            weights.

    Returns:
        The same ``model`` object. It holds the lowest-validation-loss weights
        when ``restore_best`` is set and some epoch produced a finite
        validation loss; otherwise it holds the final-epoch weights.
    """
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    best_val_loss = float('inf')
    best_state = None

    for epoch in range(num_epochs):
        avg_train_loss = _reconstruction_pass(
            model, train_loader, criterion, device,
            desc=f"Epoch {epoch+1}/{num_epochs} [Train]", optimizer=optimizer,
        )
        avg_val_loss = _reconstruction_pass(
            model, val_loader, criterion, device,
            desc=f"Epoch {epoch+1}/{num_epochs} [Val]",
        )

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), checkpoint_path)
            # CPU copy so the best weights can be restored without re-reading the file
            # and without holding a second copy on the accelerator.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

        print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}")

    if restore_best and best_state is not None:
        model.load_state_dict(best_state)

    return model

def extract_features(model, data_loader, device):
    """Encode every batch from ``data_loader`` with ``model.encode`` (eval mode, no gradients).

    Returns:
        ``(features, labels, filenames)``. ``features`` is the per-batch
        embeddings stacked row-wise with ``np.vstack``. ``labels`` is the
        per-batch label arrays joined with ``np.concatenate``. ``filenames``
        is a flat list in loader order.
    """
    model.eval()
    feature_batches, label_batches, names = [], [], []

    with torch.no_grad():
        for batch_images, batch_labels, batch_names in tqdm(data_loader, desc="Extracting features"):
            embeddings = model.encode(batch_images.to(device))
            feature_batches.append(embeddings.cpu().numpy())
            label_batches.append(batch_labels.numpy())
            names += batch_names

    return np.vstack(feature_batches), np.concatenate(label_batches), names

def cluster_features(features, n_clusters=15):
    """Fit k-means on ``features`` and return ``(fitted_kmeans, cluster_id_per_row)``.

    Uses 10 initialisations and a fixed ``random_state=0`` so repeated runs
    on the same features give the same clustering.
    """
    model = KMeans(n_clusters=n_clusters, n_init=10, random_state=0)
    assignments = model.fit_predict(features)
    return model, assignments

def map_clusters_to_regions(cluster_labels, true_labels, n_clusters=15, n_regions=15):
    """Assign each cluster the most frequent labelled region among its members.

    Samples whose true label is -1 are ignored. Clusters with no members, or
    only unlabelled members, are left out of the mapping. On a tie, the
    smallest region id wins (``np.unique`` sorts and ``np.argmax`` takes the
    first maximum).

    Args:
        cluster_labels: Cluster id per sample (NumPy array).
        true_labels: Region id per sample, aligned with ``cluster_labels``.
        n_clusters: Cluster ids ``0 .. n_clusters - 1`` are considered.
        n_regions: Unused; kept for signature compatibility.

    Returns:
        dict mapping cluster id -> majority region id.
    """
    cluster_to_region = {}
    for cluster_id in range(n_clusters):
        member_labels = true_labels[cluster_labels == cluster_id]
        labelled = member_labels[member_labels != -1]
        if labelled.size == 0:
            continue
        regions, votes = np.unique(labelled, return_counts=True)
        cluster_to_region[cluster_id] = regions[np.argmax(votes)]
    return cluster_to_region

def _empty_clustering_results(n_regions):
    """Return the result dict used when no sample can be scored."""
    return {
        'accuracy': 0,
        'adjacent_accuracy': 0,
        'confusion_matrix': np.zeros((n_regions, n_regions)),
        'coverage': 0.0,
    }


def evaluate_clustering(cluster_labels, true_labels, cluster_to_region_map, n_regions=15):
    """Score cluster assignments against ground-truth region labels.

    Samples whose true label is -1 are treated as unlabelled and ignored.
    Each remaining sample's cluster is translated to a region through
    ``cluster_to_region_map``. Samples whose cluster has no mapping are
    excluded from the accuracy figures; ``coverage`` reports how many were
    scored. Regions form a ring, so region ``n_regions - 1`` is adjacent to
    region 0.

    Args:
        cluster_labels: Cluster id per sample.
        true_labels: Region id per sample (-1 = unlabelled), aligned with
            ``cluster_labels``. Labelled ids must lie in ``[0, n_regions)``.
        cluster_to_region_map: dict cluster id -> region id.
        n_regions: Number of regions (size of the ring / confusion matrix).

    Returns:
        dict with:
          - ``'accuracy'``: % of scored samples predicted exactly.
          - ``'adjacent_accuracy'``: % of scored samples predicted exactly
            or off by one region around the ring.
          - ``'confusion_matrix'``: ``(n_regions, n_regions)`` float counts,
            rows = true region, columns = predicted region.
          - ``'coverage'``: % of labelled samples that received a prediction.

    Raises:
        ValueError: if a scored true or predicted region id is outside
            ``[0, n_regions)``. Previously this was an IndexError, or a
            silently wrapped index for negative ids.
    """
    cluster_labels = np.asarray(cluster_labels)
    true_labels = np.asarray(true_labels)

    labelled = true_labels != -1
    n_labelled = int(np.count_nonzero(labelled))
    if n_labelled == 0:
        return _empty_clustering_results(n_regions)

    predicted = np.array([cluster_to_region_map.get(c, -1) for c in cluster_labels[labelled]])
    scored = predicted != -1
    pred_regions = predicted[scored]
    true_regions = true_labels[labelled][scored]

    if pred_regions.size == 0:
        return _empty_clustering_results(n_regions)

    for kind, regions in (('true', true_regions), ('predicted', pred_regions)):
        if regions.min() < 0 or regions.max() >= n_regions:
            raise ValueError(
                f"{kind} region ids must lie in [0, {n_regions}); got values in "
                f"[{regions.min()}, {regions.max()}]. Are the labels 1-indexed?"
            )

    accuracy = 100 * np.mean(pred_regions == true_regions)

    # Ring offset 0 = exact, 1 or n_regions-1 = neighbouring region (wraps around).
    offset = (pred_regions - true_regions) % n_regions
    adjacent = (offset == 0) | (offset == 1) | (offset == n_regions - 1)
    adjacent_accuracy = 100 * np.mean(adjacent)

    confusion_matrix = np.zeros((n_regions, n_regions))
    # Unbuffered add so repeated (true, pred) pairs are all counted.
    np.add.at(confusion_matrix, (true_regions, pred_regions), 1)

    return {
        'accuracy': accuracy,
        'adjacent_accuracy': adjacent_accuracy,
        'confusion_matrix': confusion_matrix,
        'coverage': 100 * pred_regions.size / n_labelled,
    }

def autoencoder_clustering_pipeline(train_loader, val_loader, test_loader, latent_dim=128, 
                                   n_clusters=15, n_regions=15, num_epochs=10, device="cuda"):
    """Train an autoencoder, cluster its embeddings, and score the clusters as region predictions.

    Steps:
      1. Train ``Autoencoder(latent_dim)`` on ``train_loader`` (``val_loader``
         for validation).
      2. Embed ``train_loader`` and fit k-means with ``n_clusters`` clusters.
      3. Map each cluster to its majority region from the training labels.
      4. Embed ``test_loader``, assign clusters, and evaluate.
      5. Save the model weights, k-means object and mapping to
         ``autoencoder_clustering_model.pth``.

    NOTE(review): features for k-means are extracted through ``train_loader``
    itself. If that loader applies random augmentation, the clustering inputs
    are augmented too. The saved checkpoint also pickles a scikit-learn
    object, so loading it may need ``torch.load(..., weights_only=False)`` on
    recent PyTorch -- confirm for the installed version.

    Returns:
        ``(trained_model, kmeans, cluster_to_region_map, results)`` where
        ``results`` is the dict from ``evaluate_clustering``.
    """
    print("Training autoencoder...")
    autoencoder = train_autoencoder(
        Autoencoder(latent_dim=latent_dim), train_loader, val_loader, num_epochs, device
    )

    print("Extracting features from training data...")
    train_features, train_labels, _ = extract_features(autoencoder, train_loader, device)

    print("Clustering features...")
    kmeans, train_clusters = cluster_features(train_features, n_clusters=n_clusters)

    print("Mapping clusters to regions...")
    region_map = map_clusters_to_regions(
        train_clusters, train_labels, n_clusters=n_clusters, n_regions=n_regions
    )

    print("Evaluating on test set...")
    test_features, test_labels, _ = extract_features(autoencoder, test_loader, device)
    test_clusters = kmeans.predict(test_features)
    results = evaluate_clustering(test_clusters, test_labels, region_map, n_regions=n_regions)

    checkpoint = {
        'model_state_dict': autoencoder.state_dict(),
        'kmeans': kmeans,
        'cluster_to_region_map': region_map,
        'latent_dim': latent_dim,
        'n_clusters': n_clusters,
    }
    torch.save(checkpoint, 'autoencoder_clustering_model.pth')

    return autoencoder, kmeans, region_map, results

def predict_region(model, kmeans, cluster_to_region_map, image_tensor, device):
    """Predict the region of a single image tensor (no batch dimension).

    The image is embedded with ``model.encode`` and assigned to a k-means
    cluster. That cluster is looked up in ``cluster_to_region_map``.

    Returns:
        The mapped region id plus 1, or -1 when the cluster has no mapping.

    NOTE(review): the map is built from ``Region_ID`` values unchanged, so the
    ``+ 1`` here presumes those ids are 0-based while callers want 1-based
    ids -- confirm against the label CSVs.
    """
    model.eval()
    batch = image_tensor.unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = model.encode(batch).cpu().numpy()

    cluster_id = kmeans.predict(embedding)[0]
    region = cluster_to_region_map.get(cluster_id, -1)
    return -1 if region == -1 else region + 1

# Standard ImageNet per-channel mean/std (the statistics torchvision's
# pretrained ResNet-50 weights expect).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Deterministic preprocessing: resize to 256x256, convert to tensor, normalise.
transform_val = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
    T.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Training preprocessing: same pipeline, plus random erasing applied to the
# normalised tensor with probability 0.25.
transform_base = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
    T.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    T.RandomErasing(p=0.25, scale=(0.02, 0.1), ratio=(0.3, 3.3)),
])

# --- Training split -------------------------------------------------------
# The label CSV must provide the ``filename`` and ``Region_ID`` columns read
# by AngleDataset. ``transform_base`` includes RandomErasing, and the loader
# shuffles.
image_dir_train = "Dataset/Train/images_train"
labels_path_train = "Dataset/Train/labels_train.csv"

labels_df = pd.read_csv(labels_path_train)

train_dataset = AngleDataset(image_dir_train, labels_df,transform=transform_base)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)


# --- Validation split -----------------------------------------------------
# Deterministic ``transform_val`` and fixed order.
images_dir_val = "Dataset/Val/images_val"
labels_path_val = "Dataset/Val/labels_val.csv"
labels_df_val = pd.read_csv(labels_path_val)

val_dataset = AngleDataset(images_dir_val, labels_df_val, transform_val)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

# Fall back to CPU when no GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): ``val_loader`` is passed both as the validation loader (which
# train_autoencoder uses to choose the saved checkpoint) and as the test
# loader. The "Test" figures printed below are therefore validation-set
# numbers; pass a held-out test loader for an unbiased estimate.
# NOTE(review): the pipeline also extracts the k-means training features
# through ``train_loader``, so they come from RandomErasing-augmented,
# shuffled batches. Consider a separate non-augmented loader for that step.
model, kmeans, cluster_map, results = autoencoder_clustering_pipeline(
    train_loader, val_loader, val_loader, 
    latent_dim=128, n_clusters=15, n_regions=15,
    num_epochs=10, device=device
)

# evaluate_clustering reports accuracies as percentages (already x100).
print(f"Test accuracy: {results['accuracy']:.2f}%")
print(f"Adjacent accuracy: {results['adjacent_accuracy']:.2f}%")