In [None]:
from data_pre import poly_X_train, poly_y_train, poly_X_test, poly_y_test
import torch
from torch import nn, optim
from torchvision import models, transforms
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score
import numpy as np
import cv2
from torchvision.transforms import Resize, Grayscale, ToTensor, Compose, Lambda
from torchvision.transforms import ToPILImage
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Shared converter, built once instead of once per sample.
_to_pil = ToPILImage()

# Batch size used by both the training and the test loader.
_BATCH_SIZE = 8


def _as_rgb_image(sample):
    """Turn one raw sample into a three-channel (RGB) PIL image."""
    return _to_pil(sample).convert('RGB')


def _as_float(tensor):
    """Cast an image tensor to the default float dtype."""
    return tensor.float()


# Per-sample preprocessing: RGB PIL image -> 224x224 resize -> float tensor.
transform = Compose([
    Lambda(_as_rgb_image),
    Resize((224, 224), antialias=True),
    ToTensor(),
    Lambda(_as_float),
])


def _preprocess(samples):
    """Apply `transform` to every sample, stack the results and move them to `device`."""
    return torch.stack(list(map(transform, samples))).to(device)


def _as_label_tensor(labels):
    """Build an int64 label tensor on `device`."""
    return torch.tensor(labels, dtype=torch.long).to(device)


# The whole dataset is materialised as tensors on `device` up front.
X_train = _preprocess(poly_X_train)
X_test = _preprocess(poly_X_test)

y_train = _as_label_tensor(poly_y_train)
y_test = _as_label_tensor(poly_y_test)

# Only the training loader shuffles; the test loader keeps a fixed order.
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(dataset=train_dataset, batch_size=_BATCH_SIZE, shuffle=True)
test_dataset = TensorDataset(X_test, y_test)
test_loader = DataLoader(dataset=test_dataset, batch_size=_BATCH_SIZE, shuffle=False)

# ViT-B/16 initialised from the ImageNet-1k checkpoint.
# `pretrained=True` is the deprecated legacy spelling; newer torchvision
# releases select the same checkpoint through the `weights=` enum. Fall back
# to the legacy flag only when the enum is not available.
_vit_weights = getattr(models, "ViT_B_16_Weights", None)
if _vit_weights is not None:
    model = models.vit_b_16(weights=_vit_weights.IMAGENET1K_V1)
else:
    model = models.vit_b_16(pretrained=True)
model = model.to(device)
# NOTE(review): the classification head is used exactly as shipped with the
# checkpoint; presumably it should be sized to the number of classes in the
# labels -- confirm against the dataset.

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Per-epoch metric history; train_model() appends one entry per epoch to each.
train_losses = []
valid_losses = []
train_accuracies = []
valid_accuracies = []

def _plot_history(train_acc, valid_acc, train_loss, valid_loss):
    """Plot accuracy (left) and loss (right) curves, one point per epoch."""
    # Derive the x axis from the data itself so x and y always have equal length.
    epochs = range(1, len(train_loss) + 1)

    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_acc, 'b', label='Training Accuracy')
    plt.plot(epochs, valid_acc, 'g', label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_loss, 'r', label='Training Loss')
    plt.plot(epochs, valid_loss, 'orange', label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.tight_layout()
    plt.show()


def train_model(model, train_loader, test_loader, criterion, optimizer, num_epochs=30):
    """Train `model`, evaluate it after every epoch and plot the learning curves.

    Args:
        model: network to optimise; called as ``model(images)``.
        train_loader: iterable of ``(images, labels)`` training batches.
        test_loader: iterable of ``(images, labels)`` evaluation batches.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per training batch.
        num_epochs: number of train + evaluate passes to run.

    Side effects:
        Appends one value per epoch to the module-level ``train_losses``,
        ``valid_losses``, ``train_accuracies`` and ``valid_accuracies`` lists,
        prints per-epoch metrics, and shows a matplotlib figure with the
        curves of this call.
    """
    # Batches are sent to wherever the model lives. This is a no-op when the
    # data is already there, and lets CPU-resident datasets work too.
    device = next(model.parameters()).device

    # History of this call only. The module-level lists keep growing across
    # calls, so plotting them against `num_epochs` breaks on a second run.
    run_train_losses = []
    run_valid_losses = []
    run_train_accuracies = []
    run_valid_accuracies = []

    for epoch in range(num_epochs):
        # train
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for images, labels in tqdm(train_loader):
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            # Predicted class = index of the largest logit; detached so the
            # bookkeeping never touches the autograd graph.
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Loss is averaged over batches, accuracy over samples.
        epoch_train_loss = running_loss / len(train_loader)
        epoch_train_accuracy = correct / total
        train_losses.append(epoch_train_loss)
        train_accuracies.append(epoch_train_accuracy)
        run_train_losses.append(epoch_train_loss)
        run_train_accuracies.append(epoch_train_accuracy)

        # test
        model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        all_targets = []
        all_predictions = []
        with torch.no_grad():
            for images, labels in tqdm(test_loader):
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                running_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                all_targets.extend(labels.cpu().numpy())
                all_predictions.extend(predicted.cpu().numpy())

        epoch_valid_loss = running_loss / len(test_loader)
        epoch_valid_accuracy = correct / total
        valid_losses.append(epoch_valid_loss)
        valid_accuracies.append(epoch_valid_accuracy)
        run_valid_losses.append(epoch_valid_loss)
        run_valid_accuracies.append(epoch_valid_accuracy)

        print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {train_losses[-1]}, Validation Loss: {valid_losses[-1]}")
        conf_matrix = confusion_matrix(all_targets, all_predictions)
        f1 = f1_score(all_targets, all_predictions, average='macro')
        total_accuracy = accuracy_score(all_targets, all_predictions)

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"F1 Score: {f1}")
        print(f"Total Accuracy: {total_accuracy}")
        print(f"Confusion Matrix:\n {conf_matrix}")

    _plot_history(
        run_train_accuracies,
        run_valid_accuracies,
        run_train_losses,
        run_valid_losses,
    )

# Free cached, currently unused CUDA memory before the training run starts.
torch.cuda.empty_cache()

# Kick off training and evaluation with the objects prepared above.
train_model(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=30,
)