### Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import AutoImageProcessor, SwinForImageClassification
from sklearn.metrics import accuracy_score

### Hyperparameters

In [None]:
# Baseline training configuration read by the cells below.
num_classes = 4  # number of labels requested for the classification head
num_epochs = 80  # how many times the training loop iterates
learning_rate = 1e-4  # learning rate handed to the optimizer
batch_size = 16  # batch size used by every DataLoader

In [None]:
import wandb

# Random-search sweep whose objective is to minimise the `val_loss` metric.
sweep_config = {'method': 'random'}
sweep_config['metric'] = {'name': 'val_loss', 'goal': 'minimize'}
sweep_config['parameters'] = {
    # Searched dimensions: each lists the candidate values to choose from.
    'batch_size': {'values': [16, 32, 64]},
    'learning_rate': {'values': [1e-5, 1e-4, 5e-4]},
    # Pinned settings: a single value shared by every run.
    'num_epochs': {'value': 80},
    'save_interval': {'value': 5},
    'num_classes': {'value': 4},
}

# Register the sweep under the "swinv2-tuning" project and keep its id.
sweep_id = wandb.sweep(sweep_config, project="swinv2-tuning")

### Data Preprocessing

In [None]:
# The checkpoint's image processor supplies the per-channel normalisation
# statistics, so the inputs are normalised with the checkpoint's own values.
image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-base-patch4-window8-256")

# Resize to 256x256, convert to a tensor, then normalise.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            std=image_processor.image_std,
            mean=image_processor.image_mean,
        ),
    ]
)

### Load Datasets

In [None]:
# One ImageFolder per split, all sharing the same preprocessing pipeline.
train_dataset, val_dataset, test_dataset = (
    datasets.ImageFolder(root, transform=transform)
    for root in ("dataset/train", "dataset/val", "dataset/test")
)

# Only the training split is shuffled; validation and test keep dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_loader, test_loader = (
    DataLoader(split, shuffle=False, batch_size=batch_size)
    for split in (val_dataset, test_dataset)
)

### Load Pre-trained Model

In [None]:
# The checkpoint is a Swin Transformer V2 model, so it must be loaded with the
# V2 architecture class; the V1 class (SwinForImageClassification) defines a
# different set of weights and cannot reuse this checkpoint properly.
from transformers import Swinv2ForImageClassification

# Single source of truth for the compute device used by this notebook.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = Swinv2ForImageClassification.from_pretrained(
    "microsoft/swinv2-base-patch4-window8-256",
    num_labels=num_classes,
    # The checkpoint's classification head was trained for a different label
    # set than `num_classes`; allow the head to be re-initialised instead of
    # failing on the shape mismatch. The backbone weights are still loaded.
    ignore_mismatched_sizes=True,
)
model.to(device)

### Loss and Optimizer

In [None]:
# AdamW updates every model parameter; cross-entropy is the training objective.
optimizer = optim.AdamW(params=model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

### Training Function

In [None]:
def train(model, train_loader, criterion, optimizer):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Module whose forward call returns an object exposing ``.logits``.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Callable computing the loss from ``(logits, labels)``.
        optimizer: Optimizer that updates the parameters of ``model``.

    Returns:
        The mean of the per-batch loss values as a float.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    # Take the device from the model itself instead of relying on a
    # module-level ``device`` name that may not be defined yet when this
    # function is called.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    total_loss = 0.0
    num_batches = 0
    for images, labels in train_loader:
        images, labels = images.to(target_device), labels.to(target_device)

        optimizer.zero_grad()  # Clear gradients left over from the previous step
        outputs = model(images).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Counting batches here (rather than calling len()) also supports
        # loaders that do not define a length.
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")
    return total_loss / num_batches

### Validation Function

In [None]:
def validate(model, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` without updating its weights.

    Args:
        model: Module whose forward call returns an object exposing ``.logits``.
        val_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Callable computing the loss from ``(logits, labels)``.

    Returns:
        A ``(mean_batch_loss, accuracy)`` tuple, where accuracy is the fraction
        of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    # Take the device from the model itself instead of relying on a
    # module-level ``device`` name that may not be defined yet when this
    # function is called.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    total_loss = 0.0
    num_batches = 0
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(target_device), labels.to(target_device)
            outputs = model(images).logits
            total_loss += criterion(outputs, labels).item()
            num_batches += 1

            # Accumulate integer counts rather than per-sample lists; the
            # resulting fraction is the same as comparing the full
            # prediction and label sequences at the end.
            preds = torch.argmax(outputs, dim=1)
            num_correct += (preds == labels).sum().item()
            num_samples += labels.numel()

    if num_batches == 0:
        raise ValueError("val_loader yielded no batches")
    accuracy = num_correct / num_samples if num_samples else 0.0
    return total_loss / num_batches, accuracy

### Training and Validation Loop

In [None]:
# Prefer the GPU when one is visible to PyTorch, otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Each epoch: one training pass, one validation pass, then a one-line report.
for epoch in range(num_epochs):
    train_loss = train(
        model=model,
        train_loader=train_loader,
        criterion=criterion,
        optimizer=optimizer,
    )
    val_loss, val_accuracy = validate(
        model=model,
        val_loader=val_loader,
        criterion=criterion,
    )
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")


### Evaluate on Test Set

In [None]:
# Final evaluation on the held-out test split: gather every prediction and
# label, then report the overall accuracy.
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():  # no gradients are needed for evaluation
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images).logits
        preds = outputs.argmax(dim=1)  # predicted class = highest logit
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

test_accuracy = accuracy_score(all_labels, all_preds)
print(f"Test Accuracy: {test_accuracy:.4f}")