In [None]:
import os
import torch
import torch.nn as nn
from transformers import SwinForImageClassification, SwinConfig
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split as tts
from PIL import Image
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix

In [None]:
# Seed torch's RNG so the freshly initialised classification head (and any later
# torch-driven shuffling) is repeatable after Restart Kernel -> Run All.
torch.manual_seed(42)

# Load the pretrained Swin-Tiny backbone with a 2-way classification head.
# Passing num_labels together with ignore_mismatched_sizes=True lets the library
# discard the checkpoint's original head and build a new one of the right size,
# keeping the config and the model's own label count in agreement (replacing
# `model.classifier` by hand only updates the config).
num_classes = 2
model = SwinForImageClassification.from_pretrained(
    "microsoft/swin-tiny-patch4-window7-224",
    num_labels=num_classes,
    ignore_mismatched_sizes=True,
)

In [None]:
# Preprocessing shared by every image: force a 224x224 input, convert to a
# tensor, then normalise each channel with the ImageNet mean / std.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
]
transform = transforms.Compose(preprocessing_steps)

# Build the labelled dataset from the class sub-folders on disk
dataset = ImageFolder(root='data/original-dataset', transform=transform)

In [None]:
# Hold out 20% of the samples for validation.
# - random_state pins the shuffle so the split is identical on every run.
# - stratify keeps the class ratio the same in both parts.
# Later cells treat train_dataset as a plain list (concatenation with `+`).
train_dataset, test_dataset = tts(
    dataset,
    test_size=0.2,
    random_state=42,
    stratify=dataset.targets,
)

In [None]:
# Split sizes together with the class-name -> label-index mapping.
# A cell only renders its final expression, so everything is gathered into a
# single tuple instead of two separate statements (the first would be discarded).
len(train_dataset), len(test_dataset), dataset.class_to_idx

In [None]:
folder_path = "data/NON_CANCER(Augmented)"

# Only pick up real image files (stray entries such as .DS_Store would make
# Image.open fail). The extension list mirrors the one torchvision's ImageFolder
# accepts. Sorting makes the order, and therefore which samples are dropped
# below, deterministic; os.listdir returns entries in arbitrary order.
image_extensions = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
image_files = sorted(
    f for f in os.listdir(folder_path)
    if f.lower().endswith(image_extensions)
)

# NOTE(review): every augmented image is labelled 1; this assumes 1 is the
# non-cancer index in dataset.class_to_idx - confirm against that mapping.
augmented_images = []
for image_file in image_files:
    img_path = os.path.join(folder_path, image_file)
    # `with` closes the file handle promptly; convert("RGB") matches what
    # ImageFolder's loader does, so grayscale / RGBA files do not break the
    # 3-channel Normalize step.
    with Image.open(img_path) as img:
        augmented_images.append((transform(img.convert("RGB")), 1))

# The last two augmented samples are left out of the training set, as in the
# original cell (reason not visible here - TODO confirm it is still wanted).
# Slicing instead of popping never removes original training samples when fewer
# than two augmented images exist.
# NOTE: re-running this cell appends the augmented images again; re-run the
# split cell first.
train_dataset = train_dataset + augmented_images[:-2]

In [None]:
# Sanity check: size of the enlarged training set next to the number of augmented images read
tuple(len(collection) for collection in (train_dataset, augmented_images))

In [None]:
# Mini-batches of 32 for both splits; only the training split is reshuffled each epoch
loader_batch_size = 32
train_loader = DataLoader(dataset=train_dataset, batch_size=loader_batch_size, shuffle=True)
val_loader = DataLoader(dataset=test_dataset, batch_size=loader_batch_size, shuffle=False)

In [None]:
# Prefer the GPU when one is visible, otherwise stay on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

# Cross-entropy objective, optimised with Adam at a small fine-tuning learning rate
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=2e-5)

# Training loop
def _evaluate(model, data_loader, device):
    """Run ``model`` over ``data_loader`` without gradients and collect labels.

    Returns a ``(true_labels, predicted_labels)`` pair of plain Python lists;
    each prediction is the arg-max over the model's ``.logits``.
    """
    true_labels = []
    predicted_labels = []

    model.eval()
    with torch.no_grad():
        for images, labels in data_loader:
            predicted = model(images.to(device)).logits.argmax(dim=1)
            true_labels.extend(labels.tolist())
            predicted_labels.extend(predicted.tolist())

    return true_labels, predicted_labels


def train(model, train_loader, val_loader, criterion, optimizer, num_epochs=10,
          metrics_path="metrics/swin-model-metrics.txt",
          checkpoint_dir="swin-transformer-models"):
    """Fine-tune ``model`` and, after every epoch, validate, log and checkpoint it.

    Args:
        model: Module whose forward pass returns an object with a ``.logits`` tensor.
        train_loader: Iterable of ``(images, labels)`` batches used for optimisation.
        val_loader: Iterable of ``(images, labels)`` batches used for evaluation.
        criterion: Loss callable invoked as ``criterion(logits, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
        metrics_path: Text file that receives one appended metrics line per epoch.
        checkpoint_dir: Directory receiving ``model{epoch}.pth`` state dicts
            (``epoch`` is zero-based, matching the ``Model{epoch}`` log prefix).

    Returns:
        None. Results are printed, appended to ``metrics_path`` and saved under
        ``checkpoint_dir``.
    """
    # Use whatever device the model already lives on instead of relying on a
    # notebook-level global, so batches always end up next to the weights.
    device = next(model.parameters()).device

    # Create the output locations up front; otherwise a missing folder would
    # only surface after a full epoch of training has been spent.
    metrics_dir = os.path.dirname(metrics_path)
    if metrics_dir:
        os.makedirs(metrics_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for images, labels in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}"):
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            optimizer.zero_grad()
            outputs = model(images).logits
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Guard against an empty loader so the epoch summary cannot divide by zero
        num_batches = len(train_loader)
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss}")

        # Evaluate on validation set
        true_labels, predicted_labels = _evaluate(model, val_loader, device)
        total = len(true_labels)

        if total == 0:
            print("Validation set is empty; skipping metrics for this epoch.")
        else:
            correct = sum(t == p for t, p in zip(true_labels, predicted_labels))
            accuracy = 100 * correct / total
            print(f"Validation Accuracy: {accuracy:.2f}%")

            # average='binary' scores label 1 as the positive class.
            # NOTE(review): confirm label 1 is the class these metrics should
            # describe. zero_division=0 keeps the value sklearn would report for
            # an undefined metric without emitting a warning every epoch.
            precision = precision_score(true_labels, predicted_labels, average='binary', zero_division=0)
            recall = recall_score(true_labels, predicted_labels, average='binary', zero_division=0)
            f1 = f1_score(true_labels, predicted_labels, average='binary', zero_division=0)

            # `with` guarantees the handle is closed even if the write fails
            with open(metrics_path, "a") as file_handler:
                file_handler.write(f"Model{epoch}: Accuracy ({accuracy}) | Precision ({precision}) | Recall ({recall}) | F1 ({f1})\n")

        torch.save(model.state_dict(), os.path.join(checkpoint_dir, f'model{epoch}.pth'))
        
    
# Kick off fine-tuning for five epochs
train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=5,
)