# In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import confusion_matrix
import seaborn as sns
from transformers import DeiTForImageClassificationWithTeacher

# Set device: use the GPU when CUDA is available, otherwise fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Paths and parameters
data_dir = 'path_to_your_dataset'  # Your dataset path (5 subfolder classes)
run_folder = 'results/run1'        # Results folder (run1, run2,... etc.)
batch_size = 16
num_epochs = 50  # Max epochs
patience = 5     # Early stopping patience (epochs without val-loss improvement)
image_size = 224  # Input size for DeiT and ResNet
best_val_loss = float('inf')  # Lowest validation loss seen so far
early_stop_counter = 0        # Consecutive epochs without improvement

# Ensure run_folder exists. exist_ok=True creates missing parents and is a
# no-op when the folder is already there, avoiding the check-then-create race.
os.makedirs(run_folder, exist_ok=True)

# Transformations: random resized crop for training, plain resize for validation.
#
# Both pretrained backbones used in this script (torchvision ResNet-50 and
# facebook/deit-base-distilled-patch16-224) were trained on inputs normalized
# with the per-channel ImageNet statistics, so the same statistics are applied
# here instead of a single 0.5/0.5 value broadcast over all channels.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std)
])

val_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std)
])

# Load datasets.
#
# random_split returns Subset objects that share ONE underlying dataset, so
# assigning `.dataset.transform` on the train and val subsets in turn would
# leave both using whichever transform was assigned last. To give each split
# its own transform, the folder is opened twice (once per transform) and the
# validation indices drawn by random_split are applied to the second view.
# ImageFolder enumerates classes and files in sorted order, so an index refers
# to the same image in both views.
dataset = datasets.ImageFolder(root=data_dir, transform=train_transform)
val_view = datasets.ImageFolder(root=data_dir, transform=val_transform)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_split = random_split(dataset, [train_size, val_size])
val_dataset = torch.utils.data.Subset(val_view, val_split.indices)

# Shuffle only the training data; validation order does not matter for metrics
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Number of output classes, taken from the dataset's class subfolders
num_classes = len(dataset.classes)

# Optional checkpoint holding a ResNet-50 already fine-tuned on this dataset
teacher_weights_path = 'path_to_finetuned_teacher.pth'

# Define the teacher model (ResNet).
# NOTE: `pretrained=True` is deprecated in recent torchvision releases in
# favour of the `weights=` argument; it is kept here for compatibility.
teacher_model = models.resnet50(pretrained=True)
# Replacing `fc` gives a randomly initialised head: unless fine-tuned weights
# are loaded below, the teacher's predictions on this dataset are not meaningful.
teacher_model.fc = nn.Linear(teacher_model.fc.in_features, num_classes)
if os.path.isfile(teacher_weights_path):
    teacher_model.load_state_dict(
        torch.load(teacher_weights_path, map_location=device))
else:
    print(f'WARNING: no fine-tuned teacher weights found at {teacher_weights_path}; '
          'the teacher classification head is randomly initialised, so '
          'distillation targets will be uninformative.')
teacher_model = teacher_model.to(device)
teacher_model.eval()  # Teacher should be in evaluation mode
for teacher_param in teacher_model.parameters():
    teacher_param.requires_grad_(False)  # The teacher is never updated

# Define the DeiT model with knowledge distillation.
# The checkpoint ships 1000-class heads; ignore_mismatched_sizes=True lets
# from_pretrained discard them and create fresh heads of size `num_classes`
# instead of failing on the shape mismatch.
model = DeiTForImageClassificationWithTeacher.from_pretrained(
    'facebook/deit-base-distilled-patch16-224',
    num_labels=num_classes,
    ignore_mismatched_sizes=True)
model.to(device)

# Optimisation setup: cross-entropy objective, Adam over the DeiT parameters
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Per-epoch metric logs (one entry appended per completed epoch)
train_loss_history, val_loss_history = [], []
train_acc_history, val_acc_history = [], []

# Early-stopping flag; the training loop sets it once patience is exhausted
early_stop = False

# Training loop with early stopping.
#
# Distillation uses the hard-label DeiT recipe. The model's forward pass takes
# only the images (it has no `teacher_logits` argument) and returns three sets
# of logits: `cls_logits` (class-token head), `distillation_logits`
# (distillation-token head) and `logits` (their average, used for prediction).
# The class-token head is trained against the ground-truth labels and the
# distillation-token head against the teacher's predicted labels, with equal
# weight on both terms.

# Fixed label order so the confusion matrix always has one row/column per
# class, even when a class is absent from an epoch's validation predictions.
class_ids = list(range(len(dataset.classes)))

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    # Training phase
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        batch_n = labels.size(0)

        optimizer.zero_grad()

        # Teacher's hard decisions (distillation targets); no gradients needed
        with torch.no_grad():
            teacher_labels = teacher_model(images).argmax(dim=1)

        # Forward pass through the student and combined distillation loss
        outputs = model(images)
        loss = 0.5 * (criterion(outputs.cls_logits, labels)
                      + criterion(outputs.distillation_logits, teacher_labels))
        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the mean
        running_loss += loss.item() * batch_n
        predicted = outputs.logits.argmax(dim=1)
        total += batch_n
        correct += (predicted == labels).sum().item()

    train_loss = running_loss / total
    train_acc = 100 * correct / total
    train_loss_history.append(train_loss)
    train_acc_history.append(train_acc)

    # Validation phase
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            batch_n = labels.size(0)

            # Validation scores the averaged logits against the true labels
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            val_loss += loss.item() * batch_n

            predicted = outputs.argmax(dim=1)
            val_total += batch_n
            val_correct += (predicted == labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    val_loss /= val_total
    val_acc = 100 * val_correct / val_total
    val_loss_history.append(val_loss)
    val_acc_history.append(val_acc)

    # Print metrics for this epoch
    print(f'Epoch {epoch + 1}/{num_epochs}, '
          f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
          f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

    # Check for early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stop_counter = 0
        # Save the best model weights
        torch.save(model.state_dict(), os.path.join(run_folder, 'best_model_weights.pth'))
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            early_stop = True

    # Confusion matrix for validation set
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=dataset.classes, yticklabels=dataset.classes)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(f'Confusion Matrix (Epoch {epoch + 1})')
    plt.savefig(os.path.join(run_folder, f'confusion_matrix_epoch_{epoch + 1}.png'))
    plt.close()

    # Stop after this epoch's artifacts are written; checking here (rather than
    # at the top of the next iteration) also reports a stop on the last epoch.
    if early_stop:
        print("Early stopping triggered. Training stopped.")
        break

# Persist the weights as they stand when training ends (last epoch reached or
# early stop); the lowest-validation-loss weights are saved separately.
model_save_path = os.path.join(run_folder, 'final_model_weights.pth')
final_state = model.state_dict()
torch.save(final_state, model_save_path)
print(f'Final model weights saved to {model_save_path}')

# Plot accuracy and loss graphs: one figure per metric, train vs. validation.
# Each spec is (y-axis label, title, output file name, curves to draw).
plot_specs = (
    ('Accuracy (%)', 'Training and Validation Accuracy', 'accuracy_graph.png',
     ((train_acc_history, 'Train Accuracy'),
      (val_acc_history, 'Validation Accuracy'))),
    ('Loss', 'Training and Validation Loss', 'loss_graph.png',
     ((train_loss_history, 'Train Loss'),
      (val_loss_history, 'Validation Loss'))),
)

for y_label, fig_title, file_name, curves in plot_specs:
    plt.figure()
    for series, curve_label in curves:
        plt.plot(series, label=curve_label)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.title(fig_title)
    plt.legend()
    plt.savefig(os.path.join(run_folder, file_name))
    plt.close()  # release the figure once it has been written to disk

print(f'Accuracy and loss graphs saved to {run_folder}')
