# In [None]:
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torch.optim.lr_scheduler import OneCycleLR
from timm.models import efficientnet_b2
from PIL import Image, UnidentifiedImageError
import re
import io
import json

# Index of the CUDA device to use; only consulted when CUDA is available.
GPU_INDEX = 0

# Central run configuration shared by the dataset, training and plotting code.
CONFIG = {
    'data_root': 'ABC(RSCD)/Balanced_RSCD_Dataset',  # dataset root; split folders are resolved beneath it
    'batch_size': 32,
    'num_epochs': 20,
    'learning_rate': 2e-4,  # used as the optimizer lr and as max_lr of the OneCycleLR schedule
    'weight_decay': 1e-5,
    'num_workers': 8,  # worker count passed to every DataLoader
    'device': torch.device(f'cuda:{GPU_INDEX}' if torch.cuda.is_available() else 'cpu'),
    'log_interval': 50,  # training progress is printed every this many batches
    'resume_checkpoint': None  # checkpoint path to resume from; when None the entry point looks for the latest one
}

# Output locations live inside the dataset root.
CONFIG['checkpoint_dir'] = os.path.join(CONFIG['data_root'], 'checkpoints2')
CONFIG['plots_dir'] = os.path.join(CONFIG['data_root'], 'plots')

# Create the output directories up front (this also creates data_root if it is missing).
os.makedirs(CONFIG['checkpoint_dir'], exist_ok=True)
os.makedirs(CONFIG['plots_dir'], exist_ok=True)

class RoadSurfaceDataset(Dataset):
    """Image classification dataset for one split stored under ``root_dir``.

    Class names are the sub-directory names of ``<root_dir>/train`` in sorted
    order. For the ``'train'`` split every image inside a class folder is
    labelled with that folder's index. For any other split the images are
    read directly from ``<root_dir>/<split>`` and the label is parsed from the
    file name (see ``_extract_label``); files whose parsed label is not a
    known class are skipped with a warning.

    Args:
        root_dir: Dataset root containing ``train`` and the other split folders.
        split: Name of the split folder to load (``'train'`` is special-cased).
        transform: Optional callable applied to each loaded RGB image.
    """

    # File extensions (lower-case) that are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, split, transform=None):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.samples = []
        # NOTE: with DataLoader(num_workers > 0) each worker process operates
        # on its own copy of the dataset, so increments made there are not
        # reflected in the parent process' counter.
        self.corrupted_images = 0

        train_dir = os.path.join(root_dir, 'train')
        class_folders = sorted(
            d for d in os.listdir(train_dir)
            if os.path.isdir(os.path.join(train_dir, d))
        )
        self.class_to_idx = {name: i for i, name in enumerate(class_folders)}
        self.idx_to_class = {i: name for i, name in enumerate(class_folders)}

        if split == 'train':
            for class_name in class_folders:
                class_dir = os.path.join(train_dir, class_name)
                class_idx = self.class_to_idx[class_name]
                for img_name in self._list_images(class_dir):
                    self.samples.append((os.path.join(class_dir, img_name), class_idx))
        else:
            split_dir = os.path.join(root_dir, split)
            for img_name in self._list_images(split_dir):
                label = self._extract_label(img_name)
                class_idx = self.class_to_idx.get(label)
                if class_idx is None:
                    print(f"Warning: Unknown label '{label}' in {img_name}")
                    continue
                self.samples.append((os.path.join(split_dir, img_name), class_idx))

    @classmethod
    def _list_images(cls, directory):
        """Return the image file names in ``directory`` in sorted order.

        Sorting makes the sample order independent of the platform-specific
        order returned by ``os.listdir``.
        """
        return sorted(
            name for name in os.listdir(directory)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    @staticmethod
    def _extract_label(filename):
        """Parse the class label out of a dash-separated file name.

        The label is everything after the first dash-separated part that
        starts with a digit, up to the first ``.``, with dashes replaced by
        underscores (``'123-wet-mud.jpg'`` -> ``'wet_mud'``).

        Returns:
            The label string, or ``None`` when no part starts with a digit.
        """
        parts = filename.split('-')
        for i, part in enumerate(parts):
            # part[:1] is '' for an empty part, so names with a leading or
            # doubled dash are skipped instead of raising IndexError.
            if part[:1].isdigit():
                label = '-'.join(parts[i + 1:]).split('.')[0]
                return label.replace('-', '_')
        return None

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, class_idx)`` for ``idx``, skipping unreadable files.

        When the image at ``idx`` cannot be opened, the following samples are
        tried in order (wrapping around) and the first readable one is
        returned instead.

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If no sample in the dataset could be read.
        """
        num_samples = len(self.samples)
        # At most one pass over the dataset, so a dataset made only of
        # unreadable files raises instead of looping forever.
        for _ in range(max(1, num_samples)):
            img_path, class_idx = self.samples[idx]
            try:
                # The context manager closes the underlying file handle;
                # convert() forces the pixel data to load and returns a copy.
                with Image.open(img_path) as img:
                    image = img.convert('RGB')
            except (UnidentifiedImageError, IOError):
                print(f"Corrupted image found: {img_path}")
                self.corrupted_images += 1
                idx = (idx + 1) % num_samples
                continue

            if self.transform:
                image = self.transform(image)
            return image, class_idx

        raise RuntimeError(
            f"No readable image found in split '{self.split}' under {self.root_dir}"
        )

    def get_corrupted_image_count(self):
        """Return how many unreadable images this instance has skipped so far."""
        return self.corrupted_images

def find_latest_checkpoint(checkpoint_dir):
    """Return the path of the checkpoint with the highest epoch number.

    Only files named ``epoch_<N>_..._checkpoint.pth`` (where ``<N>`` is an
    integer) are considered; other files, including ones with a non-numeric
    epoch field, are ignored rather than raising ``ValueError``.

    Args:
        checkpoint_dir: Directory to search.

    Returns:
        The joined path of the latest checkpoint, or ``None`` when the
        directory does not exist or holds no matching file.
    """
    pattern = re.compile(r'epoch_(\d+)_(?:.*_)?checkpoint\.pth')

    try:
        filenames = os.listdir(checkpoint_dir)
    except FileNotFoundError:
        return None

    numbered = []
    for filename in filenames:
        match = pattern.fullmatch(filename)
        if match:
            numbered.append((int(match.group(1)), filename))

    if not numbered:
        return None

    _, latest = max(numbered)
    return os.path.join(checkpoint_dir, latest)

def verify_dataset_structure():
    """Check that the expected split folders exist and print a short summary.

    Looks for ``train``, ``vali_20k`` and ``test_50k`` under
    ``CONFIG['data_root']``, lists the class folders found in ``train`` with
    their image counts, and prints the image count of the other two splits.

    Returns:
        ``False`` as soon as a split folder is missing or ``train`` has no
        class folders, ``True`` otherwise.
    """
    image_suffixes = ('.jpg', '.jpeg', '.png')

    def count_images(folder):
        # Regular files only, matched by extension regardless of case.
        return sum(
            1 for entry in os.listdir(folder)
            if os.path.isfile(os.path.join(folder, entry))
            and entry.lower().endswith(image_suffixes)
        )

    root = CONFIG['data_root']
    splits = [(name, os.path.join(root, name)) for name in ('train', 'vali_20k', 'test_50k')]

    for name, path in splits:
        if not os.path.exists(path):
            print(f"ERROR: Directory not found: {path}")
            return False
        print(f"Found directory: {name}")

    train_dir = splits[0][1]
    class_folders = sorted(
        entry for entry in os.listdir(train_dir)
        if os.path.isdir(os.path.join(train_dir, entry))
    )
    if len(class_folders) == 0:
        print(f"ERROR: No class folders found in {train_dir}")
        return False

    print(f"Found {len(class_folders)} classes in the training directory:")
    for position, class_name in enumerate(class_folders, start=1):
        num_samples = count_images(os.path.join(train_dir, class_name))
        print(f"  {position}. {class_name}: {num_samples} samples")

    for name, path in splits[1:]:
        num_samples = count_images(path)
        print(f"{name}: {num_samples} samples")

    return True

def get_transforms():
    """Build the image pipelines for training and for evaluation.

    Returns:
        Tuple ``(train_transform, val_transform)``. The training pipeline
        applies random crop/flip/rotation/colour/translation augmentation and
        random erasing around tensor conversion and normalisation; the
        evaluation pipeline is a fixed resize, centre crop, tensor conversion
        and normalisation.
    """
    def make_normalize():
        # Same channel statistics for both pipelines.
        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    augmentation_steps = [
        transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(30),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        make_normalize(),
        # Random erasing operates on the normalised tensor, so it comes last.
        transforms.RandomErasing(p=0.2),
    ]
    evaluation_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        make_normalize(),
    ]

    return transforms.Compose(augmentation_steps), transforms.Compose(evaluation_steps)

def get_datasets_and_loaders():
    """Create the train/validation/test datasets and their DataLoaders.

    The splits are read from ``train``, ``vali_20k`` and ``test_50k`` under
    ``CONFIG['data_root']``. Only the training loader shuffles; the
    validation and test splits share the evaluation transform.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader, idx_to_class)`` where
        ``idx_to_class`` is the training dataset's index-to-name mapping.
    """
    train_transform, val_transform = get_transforms()

    def build_dataset(split, transform):
        return RoadSurfaceDataset(
            root_dir=CONFIG['data_root'],
            split=split,
            transform=transform,
        )

    def build_loader(dataset, shuffle):
        return DataLoader(
            dataset,
            batch_size=CONFIG['batch_size'],
            shuffle=shuffle,
            num_workers=CONFIG['num_workers'],
            pin_memory=True,
        )

    train_dataset = build_dataset('train', train_transform)
    val_dataset = build_dataset('vali_20k', val_transform)
    test_dataset = build_dataset('test_50k', val_transform)

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_dataset)}")

    train_loader = build_loader(train_dataset, True)
    val_loader = build_loader(val_dataset, False)
    test_loader = build_loader(test_dataset, False)

    return train_loader, val_loader, test_loader, train_dataset.idx_to_class

def create_model(num_classes):
    """Create a pretrained EfficientNet-B2 with a new two-layer classifier head.

    Args:
        num_classes: Size of the final output layer.

    Returns:
        The model with ``classifier`` replaced by
        Dropout -> Linear(in_features, 512) -> ReLU -> Dropout -> Linear(512, num_classes).
    """
    backbone = efficientnet_b2(pretrained=True)

    # The new head takes the same input width as the classifier it replaces.
    head_layers = [
        nn.Dropout(0.3),
        nn.Linear(backbone.classifier.in_features, 512),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(512, num_classes),
    ]
    backbone.classifier = nn.Sequential(*head_layers)

    return backbone

class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    The target class receives probability ``1 - smoothing`` and the remaining
    ``smoothing`` mass is split evenly over the other ``classes - 1`` classes.

    Args:
        classes: Number of classes (the size of dimension 1 of the logits).
        smoothing: Probability mass moved away from the target class.
    """

    def __init__(self, classes, smoothing=0.1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.classes = classes
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, pred, target):
        """Return the mean smoothed cross-entropy of logits ``pred`` for ``target``."""
        off_target_mass = self.smoothing / (self.classes - 1)

        # The target distribution is a constant; keep it out of the autograd graph.
        with torch.no_grad():
            smoothed = torch.full_like(pred, off_target_mass)
            smoothed.scatter_(1, target.unsqueeze(1), self.confidence)

        log_probs = self.log_softmax(pred)
        per_sample = torch.sum(-smoothed * log_probs, dim=1)
        return per_sample.mean()

class TrainingHistory:
    """Accumulates per-epoch metrics and writes them out as a plot and JSON.

    Args:
        save_dir: Directory that receives ``training_history.png`` and
            ``training_history.json`` when ``plot`` is called.
    """

    def __init__(self, save_dir):
        self.save_dir = save_dir
        self.train_losses = []
        self.train_accs = []
        self.val_losses = []
        self.val_accs = []
        self.lr_history = []

    def update(self, train_loss, train_acc, val_loss, val_acc, learning_rate):
        """Append one epoch's metrics to the tracked lists."""
        recorded = (
            (self.train_losses, train_loss),
            (self.train_accs, train_acc),
            (self.val_losses, val_loss),
            (self.val_accs, val_acc),
            (self.lr_history, learning_rate),
        )
        for series, value in recorded:
            series.append(value)

    def plot(self):
        """Save the loss, accuracy and learning-rate curves, plus a JSON dump."""
        epochs = range(1, len(self.train_losses) + 1)

        # One entry per panel: (title, y-axis label, curves); a curve is
        # (values, line format, legend label or None).
        panels = (
            ('Training and Validation Loss', 'Loss',
             ((self.train_losses, 'b-', 'Training Loss'),
              (self.val_losses, 'r-', 'Validation Loss'))),
            ('Training and Validation Accuracy', 'Accuracy (%)',
             ((self.train_accs, 'b-', 'Training Accuracy'),
              (self.val_accs, 'r-', 'Validation Accuracy'))),
            ('Learning Rate over Time', 'Learning Rate',
             ((self.lr_history, 'g-', None),)),
        )

        plt.figure(figsize=(15, 10))
        for position, (title, ylabel, curves) in enumerate(panels, start=1):
            plt.subplot(2, 2, position)
            for values, fmt, label in curves:
                if label is None:
                    plt.plot(epochs, values, fmt)
                else:
                    plt.plot(epochs, values, fmt, label=label)
            plt.title(title)
            plt.xlabel('Epochs')
            plt.ylabel(ylabel)
            # Only panels with labelled curves get a legend.
            if any(label is not None for _, _, label in curves):
                plt.legend()
            plt.grid(True)

        plt.tight_layout()
        plt.savefig(os.path.join(self.save_dir, 'training_history.png'))
        plt.close()

        metric_names = ('train_losses', 'train_accs', 'val_losses', 'val_accs', 'lr_history')
        history_data = {name: getattr(self, name) for name in metric_names}

        with open(os.path.join(self.save_dir, 'training_history.json'), 'w') as f:
            json.dump(history_data, f)

def _history_snapshot(history):
    """Return the metric lists tracked by ``history`` as a dict for checkpoints."""
    return {
        'train_losses': history.train_losses,
        'train_accs': history.train_accs,
        'val_losses': history.val_losses,
        'val_accs': history.val_accs,
        'lr_history': history.lr_history,
    }


def _evaluate(model, loader, num_classes, criterion=None):
    """Run ``model`` over ``loader`` in eval mode and collect accuracy statistics.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        num_classes: Number of classes, used to size the per-class counters.
        criterion: Optional loss callable; when given, its per-batch value is
            averaged over the number of batches.

    Returns:
        Tuple ``(avg_loss, accuracy_percent, class_correct, class_total)``
        where the last two are lists of length ``num_classes``. ``avg_loss``
        is 0.0 when no criterion is given.
    """
    device = CONFIG['device']
    model.eval()
    loss_sum = 0.0
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            if criterion is not None:
                loss_sum += criterion(outputs, targets).item()

            _, predicted = outputs.max(1)
            # Count per class with bincount instead of a per-sample .item() loop.
            hits = targets[predicted.eq(targets)]
            class_total += torch.bincount(targets, minlength=num_classes).cpu()
            class_correct += torch.bincount(hits, minlength=num_classes).cpu()

    total = int(class_total.sum())
    correct = int(class_correct.sum())
    # max(1, ...) keeps an empty loader from raising ZeroDivisionError.
    avg_loss = loss_sum / max(1, len(loader))
    accuracy = 100. * correct / max(1, total)
    return avg_loss, accuracy, class_correct.tolist(), class_total.tolist()


def train_model():
    """Train the classifier, checkpoint every epoch and evaluate on the test split.

    Steps: verify the dataset layout, build loaders/model/optimizer/scheduler,
    optionally resume from ``CONFIG['resume_checkpoint']``, train for the
    remaining epochs (saving ``epoch_<N>_checkpoint.pth`` each epoch and
    ``best_model.pth`` whenever validation accuracy improves), save
    ``final_model.pth``, then reload the best model (if present) and report
    test accuracy overall and per class.

    Returns:
        ``None``. Returns early, after printing an error, when dataset
        verification fails or a split has no usable samples.
    """
    if not verify_dataset_structure():
        print("Dataset verification failed. Please check your directory structure.")
        return

    train_loader, val_loader, test_loader, idx_to_class = get_datasets_and_loaders()
    num_classes = len(idx_to_class)

    # Fail fast: an empty split would otherwise only surface as a
    # ZeroDivisionError after a full training epoch.
    for split_name, loader in (('train', train_loader),
                               ('validation', val_loader),
                               ('test', test_loader)):
        if len(loader.dataset) == 0:
            print(f"ERROR: No usable samples found for the {split_name} split.")
            return

    model = create_model(num_classes)
    model = model.to(CONFIG['device'])

    criterion = LabelSmoothingLoss(classes=num_classes, smoothing=0.1)

    optimizer = optim.AdamW(
        model.parameters(),
        lr=CONFIG['learning_rate'],
        weight_decay=CONFIG['weight_decay']
    )

    # The schedule is stepped once per batch, so it needs the total step count.
    scheduler = OneCycleLR(
        optimizer,
        max_lr=CONFIG['learning_rate'],
        steps_per_epoch=len(train_loader),
        epochs=CONFIG['num_epochs'],
        pct_start=0.3
    )

    history = TrainingHistory(CONFIG['plots_dir'])

    start_epoch = 0
    best_val_acc = 0.0

    if CONFIG['resume_checkpoint'] and os.path.exists(CONFIG['resume_checkpoint']):
        print(f"Resuming from checkpoint: {CONFIG['resume_checkpoint']}")
        checkpoint = torch.load(CONFIG['resume_checkpoint'], map_location=CONFIG['device'])

        model.load_state_dict(checkpoint['model_state_dict'])

        if 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # Checkpoints store the zero-based index of the last finished epoch.
        start_epoch = checkpoint.get('epoch', 0) + 1
        best_val_acc = checkpoint.get('best_val_acc', 0.0)

        if 'history' in checkpoint:
            saved_history = checkpoint['history']
            history.train_losses = saved_history.get('train_losses', [])
            history.train_accs = saved_history.get('train_accs', [])
            history.val_losses = saved_history.get('val_losses', [])
            history.val_accs = saved_history.get('val_accs', [])
            history.lr_history = saved_history.get('lr_history', [])

        print(f"Resuming from epoch {start_epoch} with best validation accuracy: {best_val_acc:.2f}%")

    print(f"Training on {CONFIG['device']}")
    print(f"Number of classes: {num_classes}")
    print("Model: EfficientNet-B2 with custom classifier")
    print(f"Batch size: {CONFIG['batch_size']}")
    print(f"Learning rate: {CONFIG['learning_rate']}")
    print(f"Weight decay: {CONFIG['weight_decay']}")
    print("Class Mapping:")
    for idx, cls_name in idx_to_class.items():
        print(f"  {idx}: {cls_name}")

    for epoch in range(start_epoch, CONFIG['num_epochs']):
        print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}")

        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        batch_time = AverageMeter()

        end = time.time()
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(CONFIG['device']), targets.to(CONFIG['device'])

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            scheduler.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()

            batch_time.update(time.time() - end)
            end = time.time()

            if (batch_idx + 1) % CONFIG['log_interval'] == 0:
                print(f"Batch: {batch_idx+1}/{len(train_loader)} | "
                      f"Loss: {train_loss/(batch_idx+1):.4f} | "
                      f"Acc: {100.*train_correct/train_total:.2f}% | "
                      f"Time: {batch_time.avg:.3f}s/batch | "
                      f"LR: {scheduler.get_last_lr()[0]:.6f}")

        train_loss = train_loss/len(train_loader)
        train_acc = 100. * train_correct / train_total

        val_loss, val_acc, class_correct, class_total = _evaluate(
            model, val_loader, num_classes, criterion
        )

        current_lr = scheduler.get_last_lr()[0]
        history.update(train_loss, train_acc, val_loss, val_acc, current_lr)
        history.plot()

        print(f"Train Loss: {train_loss:.4f} | "
              f"Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f} | "
              f"Val Acc: {val_acc:.2f}%")

        class_accuracies = [(i, 100.0 * class_correct[i] / max(1, class_total[i])) for i in range(num_classes)]
        worst_classes = sorted(class_accuracies, key=lambda x: x[1])[:3]
        print("Worst performing classes:")
        for idx, acc in worst_classes:
            print(f"  {idx_to_class[idx]}: {acc:.2f}%")

        is_best = val_acc > best_val_acc
        if is_best:
            best_val_acc = val_acc
            print(f"New best validation accuracy: {best_val_acc:.2f}%")

            best_model_path = os.path.join(CONFIG['checkpoint_dir'], 'best_model.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_acc': val_acc,
                'train_acc': train_acc,
                'best_val_acc': best_val_acc,
                'class_mapping': idx_to_class,
                'history': _history_snapshot(history)
            }, best_model_path)
            print(f"Best model saved to {best_model_path}")

        epoch_checkpoint_path = os.path.join(CONFIG['checkpoint_dir'], f'epoch_{epoch+1}_checkpoint.pth')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
            'best_val_acc': best_val_acc,
            'class_mapping': idx_to_class,
            'history': _history_snapshot(history)
        }, epoch_checkpoint_path)
        print(f"Checkpoint saved to {epoch_checkpoint_path}")

    # Report skipped images only after training has actually read the data;
    # before the first batch the counter is always zero.
    # NOTE: with num_workers > 0 the counting happens in worker-process copies
    # of the dataset, so this main-process counter may still read zero.
    train_dataset = train_loader.dataset
    if hasattr(train_dataset, 'get_corrupted_image_count'):
        corrupted_count = train_dataset.get_corrupted_image_count()
        if corrupted_count > 0:
            print(f"Warning: {corrupted_count} corrupted images were skipped during training")

    final_model_path = os.path.join(CONFIG['checkpoint_dir'], 'final_model.pth')
    torch.save({
        'model_state_dict': model.state_dict(),
        'class_mapping': idx_to_class,
        'best_val_acc': best_val_acc
    }, final_model_path)
    print(f"\nFinal model saved to {final_model_path}")

    # Test with the best validation weights when they exist, otherwise with
    # the weights from the last epoch.
    best_model_path = os.path.join(CONFIG['checkpoint_dir'], 'best_model.pth')
    if os.path.exists(best_model_path):
        print("\nLoading best model for test evaluation...")
        checkpoint = torch.load(best_model_path, map_location=CONFIG['device'])
        model.load_state_dict(checkpoint['model_state_dict'])

    print("\nEvaluating on test set...")
    _, test_acc, class_correct, class_total = _evaluate(model, test_loader, num_classes)
    print(f"Test Accuracy: {test_acc:.2f}%")

    print("\nPer-class test accuracy:")
    for i in range(num_classes):
        if class_total[i] > 0:
            class_acc = 100.0 * class_correct[i] / class_total[i]
            print(f"  {idx_to_class[i]}: {class_acc:.2f}% ({class_correct[i]}/{class_total[i]})")

    print("\nTraining completed!")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    print(f"Final test accuracy: {test_acc:.2f}%")

class AverageMeter:
    """Tracks the latest value and the weighted running average of a quantity."""

    def __init__(self):
        # val: last value seen; sum/count: weighted totals; avg: sum / count.
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum, self.count = self.sum + val * n, self.count + n
        self.avg = self.sum / self.count

def plot_prediction_examples(model, test_loader, idx_to_class, num_examples=5):
    """Save a grid of sample predictions to ``prediction_examples.png``.

    Batches from ``test_loader`` are scanned in order; up to ``num_examples``
    correctly classified and up to ``num_examples`` misclassified images are
    drawn, each titled with its true and predicted class name.

    Args:
        model: Network used for prediction; it is switched to eval mode.
        test_loader: Iterable yielding ``(inputs, targets)`` batches.
        idx_to_class: Mapping from class index to class name.
        num_examples: Maximum number of examples per outcome.
    """
    model.eval()

    plt.figure(figsize=(20, 4*num_examples))

    # Statistics used to undo the input normalisation before display.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    shown = {True: 0, False: 0}  # examples drawn so far, keyed by correctness
    slot = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(CONFIG['device']), targets.to(CONFIG['device'])
            _, predicted = model(inputs).max(1)

            true_labels = targets.tolist()
            pred_labels = predicted.tolist()

            for i, (true_label, pred_label) in enumerate(zip(true_labels, pred_labels)):
                is_correct = pred_label == true_label
                if shown[is_correct] >= num_examples:
                    continue
                shown[is_correct] += 1

                # Channel-first tensor -> height x width x channel array.
                img = inputs[i].cpu().numpy().transpose(1, 2, 0)
                img = np.clip(std * img + mean, 0, 1)

                slot += 1
                plt.subplot(num_examples*2, 5, slot)
                plt.imshow(img)
                plt.title(f"True: {idx_to_class[true_label]}\n"
                          f"Pred: {idx_to_class[pred_label]}\n"
                          f"{'Correct' if is_correct else 'Incorrect'}")
                plt.axis('off')

            # Stop once both quotas are filled; checked after each full batch.
            if shown[True] >= num_examples and shown[False] >= num_examples:
                break

    plt.tight_layout()
    plt.savefig(os.path.join(CONFIG['plots_dir'], 'prediction_examples.png'))
    plt.close()

# Script entry point: auto-resume, train, render prediction examples, report runtime.
if __name__ == "__main__":
    # No explicit checkpoint configured: fall back to the newest epoch checkpoint, if any.
    if CONFIG['resume_checkpoint'] is None:
        latest_checkpoint = find_latest_checkpoint(CONFIG['checkpoint_dir'])
        if latest_checkpoint:
            print(f"Found latest checkpoint: {latest_checkpoint}")
            CONFIG['resume_checkpoint'] = latest_checkpoint
    
    start_time = time.time()
    train_model()
    
    # Rebuild the model from the best checkpoint so the plot does not depend on
    # objects local to train_model().
    best_model_path = os.path.join(CONFIG['checkpoint_dir'], 'best_model.pth')
    if os.path.exists(best_model_path):
        print("Loading best model for prediction examples...")
        checkpoint = torch.load(best_model_path, map_location=CONFIG['device'])
        
        num_classes = len(checkpoint['class_mapping'])
        model = create_model(num_classes)
        model = model.to(CONFIG['device'])
        model.load_state_dict(checkpoint['model_state_dict'])
        
        # Only the test loader is needed here; the call still builds all three datasets.
        _, _, test_loader, _ = get_datasets_and_loaders()
        
        plot_prediction_examples(model, test_loader, checkpoint['class_mapping'])
    
    # Report the wall-clock time of the whole run as hours/minutes/seconds.
    total_time = time.time() - start_time
    hours, remainder = divmod(total_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    print(f"\nTotal execution time: {int(hours)}h {int(minutes)}m {int(seconds)}s")