In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
import numpy as np
from tqdm import tqdm
import timm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import kornia.augmentation as K  # GPU-side batch augmentations, used by a later cell

data_dir = "dataset_preprocessed"
image_size = 299
batch_size = 32

# Both splits get the same, independent pipeline: PIL image -> float tensor in [0, 1],
# then per-channel (x - 0.5) / 0.5, i.e. values in [-1, 1]. No resizing happens here,
# so images under `data_dir` are presumably already image_size x image_size — TODO confirm.
train_transforms, test_transforms = (
    transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    for _split in ("train", "test")
)

In [3]:
num_workers = 4

# One ImageFolder per split directory (<data_dir>/train, <data_dir>/test), built in that order.
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), split_tfms)
    for split, split_tfms in (('train', train_transforms), ('test', test_transforms))
}

# Only the training split is shuffled; pin_memory speeds up host -> GPU copies.
dataloaders = {
    split: DataLoader(
        split_ds,
        batch_size=batch_size,
        shuffle=(split == 'train'),
        num_workers=num_workers,
        pin_memory=True,
    )
    for split, split_ds in image_datasets.items()
}

In [4]:
class TransformerBlock(nn.Module):
    """Post-norm Transformer encoder layer: self-attention followed by a ReLU MLP.

    Each sub-layer's output passes through dropout, is added back onto that
    sub-layer's input (residual connection), and the sum is LayerNorm-ed.

    Args:
        embed_dim: size of the last dimension of the input and output.
        num_heads: number of attention heads (must divide ``embed_dim``).
        ff_dim: hidden width of the feed-forward MLP.
        dropout: dropout used on the attention weights and on both residual branches.

    Shapes: input and output are ``(batch, seq_len, embed_dim)`` (``batch_first=True``).
    """

    def __init__(self, embed_dim=512, num_heads=8, ff_dim=1024, dropout=0.1):
        super().__init__()
        # Attribute names and construction order are deliberately stable: saved
        # checkpoints key on the names, and the order fixes how weights are initialised.
        self.attention = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        expand = nn.Linear(embed_dim, ff_dim)
        contract = nn.Linear(ff_dim, embed_dim)
        self.ffn = nn.Sequential(expand, nn.ReLU(), contract)
        self.norm1, self.norm2 = nn.LayerNorm(embed_dim), nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        attended, _attn_weights = self.attention(x, x, x)  # self-attention: q = k = v = x
        x = self.norm1(x + self.dropout(attended))
        return self.norm2(x + self.dropout(self.ffn(x)))

In [5]:
class XceptionTransformer(nn.Module):
    """Image classifier: timm CNN backbone + Transformer encoder over spatial tokens.

    Pipeline: backbone feature map -> 1x1 conv to ``embed_dim`` channels -> the H x W
    grid is flattened into H*W tokens -> ``transformer_layers`` ``TransformerBlock``s ->
    mean over tokens -> MLP head producing ``num_classes`` logits.

    Args:
        num_classes: number of output logits.
        embed_dim, num_heads, ff_dim: sizes forwarded to each ``TransformerBlock``.
        transformer_layers: number of stacked ``TransformerBlock``s.
        dropout: dropout inside each ``TransformerBlock``.
        backbone_name: timm model used as the feature extractor.
        pretrained: whether timm loads pretrained weights for the backbone.
        head_hidden: hidden width of the classification MLP.
        head_dropout: dropout applied before the final linear layer.

    NOTE(review): no positional embedding is added to the tokens, so the encoder plus
    mean-pooling cannot exploit token position. Adding one would change the state_dict.
    """

    def __init__(self, num_classes=1, embed_dim=512, num_heads=8, ff_dim=1024, transformer_layers=1,
                 dropout=0.2, backbone_name='xception', pretrained=True, head_hidden=256,
                 head_dropout=0.3):
        super().__init__()
        # 1. CNN backbone with no classifier (num_classes=0) and no global pooling
        #    (global_pool=''), so forward_features yields the spatial feature map.
        self.backbone = timm.create_model(backbone_name, pretrained=pretrained, num_classes=0, global_pool='')

        # Channel count comes from the backbone itself (2048 for xception) instead of
        # being hard-coded, so other timm backbones plug in without edits.
        self.proj_conv = nn.Conv2d(self.backbone.num_features, embed_dim, kernel_size=1)

        # 2. Transformer encoder over the flattened spatial positions.
        self.transformer = nn.Sequential(*[
            TransformerBlock(embed_dim, num_heads, ff_dim, dropout) for _ in range(transformer_layers)
        ])

        # 3. Token mean-pooling + MLP classification head.
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(embed_dim, head_hidden),
            nn.ReLU(),
            nn.Dropout(head_dropout),
            nn.Linear(head_hidden, num_classes),
        )

    def forward(self, x):
        feats = self.backbone.forward_features(x)   # (B, num_features, H', W'); 10x10 for 299x299 xception input
        feats = self.proj_conv(feats)               # (B, embed_dim, H', W')

        tokens = feats.flatten(2).transpose(1, 2)   # (B, H'*W', embed_dim)
        tokens = self.transformer(tokens)           # (B, H'*W', embed_dim)

        pooled = self.pool(tokens.transpose(1, 2)).squeeze(-1)  # (B, embed_dim)
        return self.fc(pooled)                      # (B, num_classes)

In [6]:
# Use the default CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [7]:
model = XceptionTransformer().to(device)

# Single-logit output (num_classes=1): BCEWithLogitsLoss applies the sigmoid itself.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
# Stepped on the test-phase loss; halves the LR once that loss has failed to
# improve for more than `patience` consecutive epochs.
scheduler = lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=2,
    verbose=True,
)

print(
    "Model ready. Parameters:",
    sum(param.numel() for param in model.parameters() if param.requires_grad),
)

  model = create_fn(


Model ready. Parameters: 24090409




In [8]:
class _AugmentInUnitRange(nn.Module):
    """Apply an augmentation pipeline in [0, 1] pixel space to already-normalised images.

    Kornia's colour ops (ColorJitter's brightness/contrast/saturation) expect RGB
    images in [0, 1] and clamp their output to that range. The training batches,
    however, were normalised with ``Normalize(mean, std)``: with mean = std = 0.5 they
    lie in [-1, 1], so feeding them straight in would clip every negative value to 0.
    This wrapper undoes the normalisation, augments, and re-applies it.

    Args:
        augment: module mapping a (B, C, H, W) batch to a batch of the same shape.
        mean, std: per-channel statistics of the upstream ``transforms.Normalize``
            (defaults match ``train_transforms``).
    """

    def __init__(self, augment, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        super().__init__()
        self.augment = augment
        # Buffers move with `.to(device)`; shaped (1, C, 1, 1) to broadcast over (B, C, H, W).
        self.register_buffer("mean", torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, -1, 1, 1))

    def forward(self, x):
        x = x * self.std + self.mean        # normalised -> [0, 1]
        x = self.augment(x)
        return (x - self.mean) / self.std   # back to the model's input range


data_augmentation = _AugmentInUnitRange(nn.Sequential(
    K.RandomHorizontalFlip(p=0.5),
    K.RandomRotation(degrees=5),
    K.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, p=0.8),
)).to(device)

In [9]:
import glob


def load_latest_checkpoint(model, optimizer, scheduler, checkpoint_dir="checkpoints", map_location=None):
    """Restore the newest ``epoch_*.pth`` checkpoint from ``checkpoint_dir``, if one exists.

    Model, optimizer and scheduler state are loaded in place via ``load_state_dict``.

    Args:
        model, optimizer, scheduler: objects that receive the saved state dicts.
        checkpoint_dir: directory scanned for ``epoch_*.pth`` files.
        map_location: forwarded to ``torch.load``; defaults to the notebook-level ``device``.

    Returns:
        ``(start_epoch, best_acc)``: the ``'epoch'`` stored in the checkpoint and its
        ``'best_acc'`` entry (0.0 for checkpoints written without that key).
        ``(0, 0.0)`` when no checkpoint exists, so a fresh run starts from scratch
        instead of failing with IndexError.
    """
    # Epoch numbers in the file names are zero-padded (epoch_003.pth), so
    # lexicographic order is epoch order.
    ckpts = sorted(glob.glob(os.path.join(checkpoint_dir, "epoch_*.pth")))
    if not ckpts:
        print(f"No checkpoint found in {checkpoint_dir}; starting from scratch.")
        return 0, 0.0

    latest_ckpt = ckpts[-1]
    checkpoint = torch.load(latest_ckpt, map_location=map_location if map_location is not None else device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    start_epoch = int(checkpoint['epoch'])
    best_acc = float(checkpoint.get('best_acc', 0.0))
    print(f"Loaded checkpoint from {latest_ckpt} (Epoch {start_epoch}, best acc {best_acc:.4f})")
    return start_epoch, best_acc


In [10]:
def _run_phase(model, criterion, optimizer, phase):
    """Run one pass over ``dataloaders[phase]`` and return ``(mean_loss, accuracy)`` as floats.

    'train': model in train mode, batches pass through ``data_augmentation`` and the
    optimizer is stepped. Any other phase: eval mode with gradients disabled.
    Relies on the notebook-level ``dataloaders``, ``device`` and ``data_augmentation``.

    Raises:
        ValueError: if the loader yields no samples (instead of dividing by zero).
    """
    is_train = phase == 'train'
    model.train(is_train)  # model.train(False) is exactly model.eval()

    running_loss, running_corrects, total = 0.0, 0, 0
    for inputs, labels in tqdm(dataloaders[phase], desc=f"{phase}"):
        inputs = inputs.to(device)
        # Integer class ids -> (B, 1) float targets, matching the single-logit output.
        labels = labels.float().to(device).unsqueeze(1)

        if is_train:
            inputs = data_augmentation(inputs)
            optimizer.zero_grad()  # only needed when we are about to backprop

        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if is_train:
                loss.backward()
                optimizer.step()

        preds = (torch.sigmoid(outputs) > 0.5).float()
        running_loss += loss.item() * inputs.size(0)
        # Accumulate as a Python int rather than a device tensor.
        running_corrects += (preds == labels).sum().item()
        total += labels.size(0)

    if total == 0:
        raise ValueError(f"dataloaders[{phase!r}] yielded no samples")
    return running_loss / total, running_corrects / total


def train_model(model, criterion, optimizer, scheduler, num_epochs=10, checkpoint_interval=1,
                start_epoch=0, best_acc=0.0, checkpoint_dir="checkpoints"):
    """Train ``model`` and evaluate it on the 'test' split once per epoch.

    The scheduler is stepped on the test loss. Whenever test accuracy beats
    ``best_acc``, the weights are written to ``<checkpoint_dir>/best_model.pth``.
    Every ``checkpoint_interval`` epochs a resumable checkpoint is written to
    ``<checkpoint_dir>/epoch_XXX.pth``; it holds model, optimizer and scheduler
    state, the latest losses, and ``best_acc``.

    Args:
        num_epochs: total epoch count; epochs ``start_epoch .. num_epochs - 1`` are run.
        checkpoint_interval: save a resumable checkpoint every this many epochs (>= 1).
        start_epoch: first epoch to run (e.g. as returned by ``load_latest_checkpoint``).
        best_acc: best test accuracy seen so far; only improvements on it overwrite
            ``best_model.pth``.
        checkpoint_dir: directory for checkpoints (created if missing).

    Returns:
        ``(train_losses, val_losses)``: per-epoch mean losses for the epochs run.

    Raises:
        ValueError: if ``checkpoint_interval < 1`` or a loader yields no samples.
    """
    if checkpoint_interval < 1:
        raise ValueError(f"checkpoint_interval must be >= 1, got {checkpoint_interval}")
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Keep the caller's value. When resuming, it is the best test accuracy so far;
    # resetting it to 0.0 would let the first resumed epoch overwrite best_model.pth.
    best_acc = float(best_acc)
    train_losses, val_losses = [], []

    for epoch in range(start_epoch, num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 40)

        for phase in ['train', 'test']:
            epoch_loss, epoch_acc = _run_phase(model, criterion, optimizer, phase)

            if phase == 'train':
                train_losses.append(epoch_loss)
            else:
                val_losses.append(epoch_loss)
                scheduler.step(epoch_loss)

            print(f"{phase} Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.4f}")

            # Save the best model based on test accuracy.
            if phase == 'test' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_path = os.path.join(checkpoint_dir, "best_model.pth")
                torch.save(model.state_dict(), best_path)
                print(f"Best model saved to {best_path} (acc: {best_acc:.4f})")

        # Periodic resumable checkpoint; best_acc is included so a resume can restore it.
        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch+1:03d}.pth")
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': train_losses[-1],
                'val_loss': val_losses[-1],
                'best_acc': best_acc,
            }, checkpoint_path)
            print(f"Checkpoint saved at {checkpoint_path}")

    print(f"\nBest Validation Accuracy: {best_acc:.4f}")
    return train_losses, val_losses

In [11]:
# Resume from the newest saved checkpoint, then keep training until epoch 5
# (num_epochs is the total epoch count, not the number of additional epochs).
start_epoch, best_acc = load_latest_checkpoint(model, optimizer, scheduler)
train_losses, val_losses = train_model(
    model,
    criterion,
    optimizer,
    scheduler,
    start_epoch=start_epoch,
    best_acc=best_acc,
    num_epochs=5,
    checkpoint_interval=1,
)

  checkpoint = torch.load(latest_ckpt, map_location=device)


✅ Loaded checkpoint from checkpoints\epoch_002.pth (Epoch 2)

Epoch 3/5
----------------------------------------


train:  20%|█▉        | 610/3125 [1:17:29<5:19:31,  7.62s/it] 


KeyboardInterrupt: 