In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset, ConcatDataset
import numpy as np
from sklearn.model_selection import KFold
from collections import deque
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from copy import deepcopy
from google.colab import drive  # For Google Drive integration



# Device configuration: use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Hyperparameters
K_FOLDS = 5  # Number of cross-validation folds over the labeled set
BATCH_SIZE = 64  # Labeled batch size (the unlabeled loader uses 5x this value)
NUM_CLASSES = 10  # Width of the classification head
MEMORY_SIZE = 512  # Size of OOD memory bank per class
ALPHA = 0.8       # Orthogonal decomposition ratio
THRESHOLD = 0.8   # Confidence threshold (not referenced by the code visible in this file - TODO confirm it is still needed)
PATIENCE = 5      # Early stopping patience
# NOTE(review): this path sits under the usual Colab Drive mount point, but no
# drive.mount(...) call is visible in this file. Confirm Drive is mounted before
# this runs; otherwise the directory is presumably created on the local runtime disk.
SAVE_DIR = '/content/drive/MyDrive/OSP_Models'  # Directory to save models

# Create save directory if it doesn't exist (runs at import time)
os.makedirs(SAVE_DIR, exist_ok=True)

# Data Augmentation Transforms
# Building blocks shared by both augmentation pipelines.
# Geometric steps applied first in the weak and the strong pipeline alike.
_flip_and_crop = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(96, padding=4),
]
# Final steps of both pipelines: tensor conversion, then per-channel normalisation.
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
]

# Weak augmentation: flip + padded crop only.
weak_transform = transforms.Compose(_flip_and_crop + _to_normalized_tensor)

# Strong augmentation: the weak steps plus an affine warp and colour jitter,
# inserted before the tensor conversion.
strong_transform = transforms.Compose(
    _flip_and_crop
    + [
        transforms.RandomAffine(degrees=30, translate=(0.2, 0.2)),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    ]
    + _to_normalized_tensor
)

# Rotation Transformation
class RotationTransform:
    """Callable that produces one rotated copy of its input per configured angle."""

    def __init__(self):
        # Rotation angles in degrees, applied in exactly this order.
        self.angles = [0, 90, 180, 270]

    def __call__(self, x):
        """Return a list holding ``x`` rotated by each entry of ``self.angles``.

        The list has the same length and order as ``self.angles``.
        """
        rotated_views = []
        for degrees in self.angles:
            rotated_views.append(transforms.functional.rotate(x, degrees))
        return rotated_views

# Model Architecture
class Encoder(nn.Module):
    """ResNet-18 backbone used as a feature extractor.

    The final fully connected layer is replaced by ``nn.Identity`` so the
    module returns the backbone's pooled features instead of class logits.
    """

    def __init__(self):
        super().__init__()
        # Randomly initialised backbone. The ``pretrained`` keyword is deprecated
        # in recent torchvision releases; omitting it yields the same
        # "no pretrained weights" default under both the old and the new API.
        self.resnet = torchvision.models.resnet18()
        # Drop the classification layer so forward() yields raw features.
        self.resnet.fc = nn.Identity()

    def forward(self, x):
        """Return the backbone features for the image batch ``x``."""
        return self.resnet(x)

class OSPModel(nn.Module):
    """Shared encoder followed by three linear heads.

    Heads: class logits (``'cls'``), 4-way rotation logits (``'rot'``) and a
    single sigmoid-activated OOD score (``'ood'``).
    """

    def __init__(self, num_classes):
        """Create the encoder and the heads; ``num_classes`` sizes the ``'cls'`` head."""
        super().__init__()
        feature_dim = 512  # input width shared by all three heads
        self.encoder = Encoder()
        self.classifier = nn.Linear(feature_dim, num_classes)
        self.rotation_head = nn.Linear(feature_dim, 4)
        self.ood_detector = nn.Linear(feature_dim, 1)

    def forward(self, x):
        """Encode ``x`` once and return a dict with keys ``'cls'``, ``'rot'`` and ``'ood'``."""
        embedding = self.encoder(x)
        class_logits = self.classifier(embedding)
        rotation_logits = self.rotation_head(embedding)
        # The OOD head is squashed to (0, 1); the other heads return raw logits.
        ood_probability = torch.sigmoid(self.ood_detector(embedding))
        return {'cls': class_logits, 'rot': rotation_logits, 'ood': ood_probability}

# OSP Components
class AOM:
    """Per-class FIFO memory of feature vectors selected as OOD candidates.

    Each class owns a ``deque`` capped at ``memory_size`` entries, so the oldest
    feature is dropped automatically once a bank is full.
    """

    def __init__(self, num_classes, memory_size):
        """Create ``num_classes`` empty banks, each holding at most ``memory_size`` features."""
        self.memory_banks = [deque(maxlen=memory_size) for _ in range(num_classes)]

    def update_memory(self, features, pseudo_labels, ood_scores):
        """Store every feature whose score is below 0.2 in the bank of its pseudo label.

        Args:
            features: iterable of per-sample feature tensors.
            pseudo_labels: iterable of per-sample class indices (ints or
                single-element tensors).
            ood_scores: iterable of per-sample scores (floats or single-element
                tensors).
        """
        for feat, pl, score in zip(features, pseudo_labels, ood_scores):
            # Convert explicitly so single-element tensors and plain numbers
            # behave the same way for the comparison and the list index.
            if float(score) < 0.2:  # OOD threshold
                self.memory_banks[int(pl)].append(feat.detach().cpu())

    def get_ood_pair(self, class_id):
        """Return one randomly chosen stored feature of ``class_id`` as a NumPy array.

        Returns ``None`` when the bank of that class is empty.
        """
        bank = self.memory_banks[class_id]
        if not bank:
            return None
        # Pick the index first and convert only that entry. Stacking the whole
        # bank on every call (once per labeled sample) copied up to
        # ``memory_size`` vectors just to return a single row.
        selected_index = int(np.random.choice(len(bank)))
        return np.array(bank[selected_index])

class SOR(nn.Module):
    """Remove a fraction ``alpha`` of each ID feature's component along its OOD feature."""

    def __init__(self, alpha):
        """Store ``alpha``, the share of the projection that is subtracted."""
        super().__init__()
        self.alpha = alpha

    def forward(self, z_id, z_ood):
        """Return ``z_id`` minus ``alpha`` times its projection onto ``z_ood``.

        Both inputs are reduced along ``dim=1``; the result has the shape of
        ``z_id``.
        """
        # Unit direction of every OOD vector (epsilon guards against zero norm).
        ood_norm = torch.norm(z_ood, dim=1, keepdim=True)
        ood_direction = z_ood / (ood_norm + 1e-6)

        # Signed length of z_id along that direction: |z_id| * cos(angle).
        similarity = nn.functional.cosine_similarity(z_id, z_ood, dim=1, eps=1e-6)
        id_norm = torch.norm(z_id, dim=1, keepdim=True)
        scalar_projection = id_norm * similarity.unsqueeze(1)

        # Subtract the scaled projection vector from the ID features.
        projection = scalar_projection * ood_direction
        return z_id - self.alpha * projection

# Loss Functions
class OSPLoss(nn.Module):
    """Weighted sum of classification, pruned-vs-original KL and OOD detection losses."""

    def __init__(self):
        super().__init__()
        # Both criteria are attributes so callers can also use them directly.
        self.ce = nn.CrossEntropyLoss()
        self.kl = nn.KLDivLoss(reduction='batchmean')

    def forward(self, outputs, labels, pruned_outputs, ood_scores):
        """Return ``ce + 0.5 * kl + 0.1 * bce`` as a scalar tensor.

        Args:
            outputs: dict with ``'cls'`` logits and ``'ood'`` probabilities.
            labels: class targets for the cross-entropy term.
            pruned_outputs: logits computed from the pruned features.
            ood_scores: targets for the binary cross-entropy on ``outputs['ood']``.
        """
        class_logits = outputs['cls']

        # Supervised classification term.
        classification_term = self.ce(class_logits, labels)

        # KL term between the pruned prediction (log-probabilities) and the
        # original prediction (probabilities).
        pruned_log_probs = nn.functional.log_softmax(pruned_outputs, dim=1)
        original_probs = nn.functional.softmax(class_logits, dim=1)
        consistency_term = self.kl(pruned_log_probs, original_probs)

        # Binary cross-entropy of the OOD head against the supplied targets.
        detection_term = nn.functional.binary_cross_entropy(outputs['ood'], ood_scores)

        return classification_term + 0.5*consistency_term + 0.1*detection_term

# Early Stopping
class EarlyStopping:
    """Signal a stop once the validation loss has not improved for ``patience`` calls.

    A call counts as an improvement only when the loss is strictly lower than
    ``best_loss - min_delta``. Equal or NaN losses are treated as "no
    improvement" so a plateau or a diverged run still triggers the stop.
    """

    def __init__(self, patience=5, min_delta=0):
        """
        Args:
            patience: number of non-improving calls tolerated before stopping.
            min_delta: minimum decrease of the loss that counts as an improvement.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and return ``True`` once training should stop.

        After the flag has been set it stays set on later calls.
        """
        # NaN is the only value that differs from itself. It must never become
        # ``best_loss``, because every later comparison against it would be False.
        is_nan = val_loss != val_loss
        improved = (not is_nan) and (
            self.best_loss is None
            or val_loss < self.best_loss - self.min_delta
        )

        if improved:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop

# Data Preparation
def prepare_data():
    """Load the STL-10 ``train`` and ``unlabeled`` splits (downloading if needed).

    Both splits are rooted at ``./data`` and use the same preprocessing:
    conversion to a tensor followed by per-channel normalisation.

    Returns:
        tuple: ``(labeled_dataset, unlabeled_dataset)``.
    """
    # NOTE(review): these mean/std values are hard-coded; confirm they are the
    # statistics intended for STL-10.
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])

    def load_split(split_name):
        # One STL-10 split with the shared preprocessing attached.
        return torchvision.datasets.STL10(
            root='./data',
            split=split_name,
            download=True,
            transform=to_normalized_tensor,
        )

    # Labeled data first, then the unlabeled pool (which the original note
    # describes as containing OOD samples).
    labeled_dataset = load_split('train')
    unlabeled_dataset = load_split('unlabeled')
    return labeled_dataset, unlabeled_dataset

# Training Loop
def _rotation_targets(num_rotations, batch_size):
    """Return rotation-class targets for features concatenated rotation by rotation.

    The pre-training step concatenates encoder outputs as
    ``[all samples @ rotation 0, all samples @ rotation 1, ...]``, so the
    matching targets are ``batch_size`` zeros, then ``batch_size`` ones, etc.
    """
    return torch.arange(num_rotations).repeat_interleave(batch_size)


def _train_accuracy(correct, total):
    """Return the training accuracy in percent, or 0 (with a warning) if nothing was seen."""
    if total == 0:
        print("Warning: Total samples processed is 0. Train accuracy will be set to 0.")
        return 0
    return 100. * correct / total


def _stack_ood_features(ood_picks, id_features):
    """Build the OOD tensor paired row by row with ``id_features``.

    ``ood_picks[i]`` is the OOD feature drawn for sample ``i`` or ``None``. Rows
    without a pick stay all-zero, which makes the projection computed by SOR
    zero and therefore leaves that sample's feature unpruned.
    """
    ood_features = torch.zeros_like(id_features)
    for row, pick in enumerate(ood_picks):
        if pick is not None:
            ood_features[row] = torch.as_tensor(
                pick, dtype=id_features.dtype, device=id_features.device
            ).reshape(-1)
    return ood_features


def _save_training_curves(history, fold):
    """Plot the loss and accuracy curves of one fold and save them as a PNG."""
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Val Loss')
    plt.title('Loss Curve')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train Acc')
    plt.plot(history['val_acc'], label='Val Acc')
    plt.title('Accuracy Curve')
    plt.legend()

    plt.tight_layout()
    plt.savefig(f'fold_{fold+1}_curves.png')
    plt.close()


def train_osp(pretrain_epochs=2, finetune_epochs=2):
    """Run K-fold training: rotation-assisted pre-training, then OSP fine-tuning.

    For every fold a fresh model is pre-trained with a classification loss on
    labeled data plus a rotation-prediction loss on unlabeled data. It is then
    fine-tuned with ``OSPLoss`` while the AOM memory banks are filled from the
    unlabeled batches. The best checkpoint (by validation accuracy during
    fine-tuning) is written to ``SAVE_DIR`` and the training curves are saved
    as ``fold_<k>_curves.png``.

    Args:
        pretrain_epochs: number of pre-training epochs per fold.
        finetune_epochs: maximum number of fine-tuning epochs per fold (early
            stopping may end the stage sooner).

    Returns:
        None. Per-fold best validation accuracies and their mean/std are printed.
    """
    labeled_dataset, unlabeled_dataset = prepare_data()
    kfold = KFold(n_splits=K_FOLDS, shuffle=True)
    fold_results = []
    rotate_all = RotationTransform()

    for fold, (train_idx, val_idx) in enumerate(kfold.split(labeled_dataset)):
        print(f'\nFold {fold+1}/{K_FOLDS}')
        print('-'*20)

        # Data Loaders
        train_loader = DataLoader(
            Subset(labeled_dataset, train_idx),
            batch_size=BATCH_SIZE, shuffle=True
        )
        val_loader = DataLoader(
            Subset(labeled_dataset, val_idx),
            batch_size=BATCH_SIZE
        )
        unlabeled_loader = DataLoader(
            unlabeled_dataset,
            batch_size=BATCH_SIZE*5, shuffle=True
        )

        # Initialize components
        model = OSPModel(NUM_CLASSES).to(device)
        aom = AOM(NUM_CLASSES, MEMORY_SIZE)
        sor = SOR(ALPHA).to(device)
        criterion = OSPLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9, weight_decay=5e-4)
        # NOTE(review): the cosine horizon stays at 100 epochs as originally
        # configured, independent of the epoch counts passed in - confirm.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

        # Training history
        history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

        # Early stopping and best-model tracking (both used during fine-tuning)
        early_stopping = EarlyStopping(patience=PATIENCE)
        best_val_acc = 0

        # Pre-training Stage
        print('Pre-training Stage:')
        for epoch in range(pretrain_epochs):
            model.train()
            total_loss = 0
            correct = 0
            total = 0
            steps = 0

            # zip() ends with the shorter loader, so an epoch is one pass over
            # whichever of the two runs out first.
            for (inputs, labels), (unlabeled, _) in zip(train_loader, unlabeled_loader):
                inputs, labels = inputs.to(device), labels.to(device)
                unlabeled = unlabeled.to(device)

                # Rotation prediction: one encoder pass per rotated copy.
                rotated = rotate_all(unlabeled)
                rotation_labels = _rotation_targets(len(rotated), unlabeled.size(0)).to(device)
                rot_features = [model.encoder(r) for r in rotated]
                rot_outputs = model.rotation_head(torch.cat(rot_features))

                # Supervised forward pass
                outputs = model(inputs)

                # Loss calculation
                loss_cls = criterion.ce(outputs['cls'], labels)
                loss_rot = criterion.ce(rot_outputs, rotation_labels)
                loss = loss_cls + 0.5*loss_rot

                # Backprop
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Metrics
                steps += 1
                total_loss += loss.item()
                _, predicted = outputs['cls'].max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

            # Validation
            val_loss, val_acc = validate(model, val_loader)
            history['train_loss'].append(total_loss / max(steps, 1))
            history['train_acc'].append(_train_accuracy(correct, total))
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)

            print(f'Epoch {epoch+1}/{pretrain_epochs} | Loss: {history["train_loss"][-1]:.3f} | '
                  f'Acc: {history["train_acc"][-1]:.2f}% | Val Acc: {val_acc:.2f}%')
            scheduler.step()

        # Fine-tuning Stage with OSP
        print('\nFine-tuning Stage:')
        optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

        for epoch in range(finetune_epochs):
            model.train()
            total_loss = 0
            correct = 0
            total = 0
            steps = 0

            for (inputs, labels), (unlabeled, _) in zip(train_loader, unlabeled_loader):
                inputs, labels = inputs.to(device), labels.to(device)
                unlabeled = unlabeled.to(device)

                # Forward pass on labeled data
                outputs = model(inputs)

                # The unlabeled batch only feeds pseudo labels and the memory
                # banks, so no autograd graph is needed for it.
                with torch.no_grad():
                    ood_outputs = model(unlabeled)
                    unlabeled_features = model.encoder(unlabeled)

                # OOD detection
                ood_scores = (ood_outputs['ood'] < 0.5).float()
                pseudo_labels = ood_outputs['cls'].argmax(1)

                # Update AOM memory banks
                aom.update_memory(
                    unlabeled_features,
                    pseudo_labels.cpu(),
                    ood_scores.cpu()
                )

                # Draw one OOD feature per labeled sample, in sample order, so
                # row i of the OOD tensor belongs to the class of input i.
                ood_picks = [aom.get_ood_pair(class_id) for class_id in labels.tolist()]

                # Apply SOR when at least one sample has an OOD partner;
                # otherwise fall back to the unpruned logits and still train.
                if any(pick is not None for pick in ood_picks):
                    id_features = model.encoder(inputs)
                    ood_features = _stack_ood_features(ood_picks, id_features)
                    pruned_outputs = model.classifier(sor(id_features, ood_features))
                else:
                    pruned_outputs = outputs['cls']

                # Loss calculation
                # NOTE(review): the OOD targets for the labeled inputs are the
                # scores of the first len(inputs) *unlabeled* samples, as in the
                # original code. Confirm this is the intended supervision.
                ood_scores_resized = ood_scores[:inputs.size(0)].to(device)
                loss = criterion(outputs, labels, pruned_outputs, ood_scores_resized)

                # Backprop
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Metrics
                steps += 1
                total_loss += loss.item()
                _, predicted = outputs['cls'].max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

            # Validation
            val_loss, val_acc = validate(model, val_loader)
            history['train_loss'].append(total_loss / max(steps, 1))
            history['train_acc'].append(_train_accuracy(correct, total))
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)

            print(f'Epoch {epoch+1}/{finetune_epochs} | Loss: {history["train_loss"][-1]:.3f} | '
                  f'Acc: {history["train_acc"][-1]:.2f}% | Val Acc: {val_acc:.2f}%')
            scheduler.step()

            # Save best model
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), os.path.join(SAVE_DIR, f'best_model_fold_{fold+1}.pth'))
                print(f'New best model saved with val acc: {best_val_acc:.2f}%')

            # Early Stopping Check
            if early_stopping(val_loss):
                print(f'Early stopping triggered at epoch {epoch+1}')
                break

        # Save fold results (0 when no epoch was run at all)
        fold_results.append(max(history['val_acc'], default=0))

        # Plot training curves
        _save_training_curves(history, fold)

    # Final results
    print(f'\nK-Fold Results: {fold_results}')
    print(f'Mean Accuracy: {np.mean(fold_results):.2f}% (±{np.std(fold_results):.2f})')

def validate(model, val_loader):
    """Evaluate ``model`` on ``val_loader`` using its ``'cls'`` output.

    The model is switched to eval mode and is left in that mode on return.

    Returns:
        tuple: ``(mean of the per-batch cross-entropy losses, accuracy in percent)``.
    """
    loss_fn = nn.CrossEntropyLoss()
    running_loss = 0
    num_correct = 0
    num_seen = 0

    model.eval()
    with torch.no_grad():
        for batch_inputs, batch_labels in val_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)['cls']
            running_loss += loss_fn(logits, batch_labels).item()

            # Index of the largest logit per row is the predicted class.
            predictions = logits.max(1)[1]
            num_seen += batch_labels.size(0)
            num_correct += (predictions == batch_labels).sum().item()

    # The loss is averaged over batches, the accuracy over samples.
    mean_batch_loss = running_loss/len(val_loader)
    accuracy_percent = 100.*num_correct/num_seen
    return mean_batch_loss, accuracy_percent

# Entry point: run the full K-fold training when this file is executed directly.
if __name__ == '__main__':
    train_osp()

Using device: cuda

Fold 1/5
--------------------




Pre-training Stage:
Epoch 1/50 | Loss: 2.167 | Acc: 25.40% | Val Acc: 31.70%
Epoch 2/50 | Loss: 1.602 | Acc: 40.05% | Val Acc: 38.00%

Fine-tuning Stage:
Epoch 1/100 | Loss: 1.414 | Acc: 54.73% | Val Acc: 45.00%
New best model saved with val acc: 45.00%
Epoch 2/100 | Loss: 1.330 | Acc: 58.05% | Val Acc: 47.50%
New best model saved with val acc: 47.50%

Fold 2/5
--------------------
Pre-training Stage:
Epoch 1/50 | Loss: 2.089 | Acc: 27.02% | Val Acc: 30.90%
Epoch 2/50 | Loss: 1.647 | Acc: 39.60% | Val Acc: 39.70%

Fine-tuning Stage:
Epoch 1/100 | Loss: 1.404 | Acc: 56.85% | Val Acc: 47.10%
New best model saved with val acc: 47.10%
Epoch 2/100 | Loss: 1.252 | Acc: 62.15% | Val Acc: 48.90%
New best model saved with val acc: 48.90%

Fold 3/5
--------------------
Pre-training Stage:
Epoch 1/50 | Loss: 2.220 | Acc: 26.00% | Val Acc: 28.10%
Epoch 2/50 | Loss: 1.649 | Acc: 40.50% | Val Acc: 35.10%

Fine-tuning Stage:
Epoch 1/100 | Loss: 1.405 | Acc: 55.70% | Val Acc: 45.80%
New best model sav