In [None]:
# 1. Imports
# ===============================
import os
import pandas as pd
from PIL import Image
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

from tqdm import tqdm
import copy

# ===============================
# 2. Build DataFrames from file names
# ===============================
def get_dataframe_from_folder(folder_path):
    """Build a DataFrame of image paths and binary labels from file names.

    Only files ending in ``.jpg``, ``.jpeg`` or ``.png`` (case-insensitive)
    are considered. The label is derived from the lower-cased file name:
    names starting with ``notsmoking`` get label 0, names starting with
    ``smoking`` get label 1, and any other file is skipped.

    Args:
        folder_path: Directory whose direct entries are scanned.

    Returns:
        A DataFrame with the columns ``image_path`` and ``label``. The
        columns are present even when no file matches, so downstream code
        can always index ``df['label']``.
    """
    image_extensions = ('.jpg', '.jpeg', '.png')
    data = []
    # os.listdir returns entries in arbitrary order; sorting keeps the row
    # order reproducible between runs and machines.
    for file in sorted(os.listdir(folder_path)):
        fname = file.lower()
        if not fname.endswith(image_extensions):
            continue
        if fname.startswith("notsmoking"):
            label = 0
        elif fname.startswith("smoking"):
            label = 1
        else:
            continue
        data.append({
            'image_path': os.path.join(folder_path, file),
            'label': label,
        })
    # Explicit columns: an empty list would otherwise give a column-less frame.
    return pd.DataFrame(data, columns=['image_path', 'label'])

# One DataFrame per dataset split, in the order train / validation / test.
train_df, val_df, test_df = (
    get_dataframe_from_folder(split_dir)
    for split_dir in (
        "/content/smoker_data/Training/Training",
        "/content/smoker_data/Validation/Validation",
        "/content/smoker_data/Testing/Testing",
    )
)

# ===============================
# 3. Torch Dataset
# ===============================
class SmokerDataset(Dataset):
    """Torch dataset serving (image, label) pairs described by a DataFrame.

    The DataFrame must provide an ``image_path`` column (path opened with
    PIL) and a ``label`` column (converted to a float32 scalar tensor).
    """

    def __init__(self, dataframe, transform=None):
        """Store the frame with a fresh 0..n-1 index and the optional transform.

        Args:
            dataframe: Frame with ``image_path`` and ``label`` columns.
            transform: Optional callable applied to the RGB PIL image.
        """
        self.df = dataframe.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the image at position *idx* and return ``(image, label)``."""
        # Fetch the row once instead of once per column.
        row = self.df.iloc[idx]
        # The context manager releases the file handle deterministically;
        # convert() returns an independent, fully loaded image.
        with Image.open(row['image_path']) as source:
            image = source.convert('RGB')
        label = torch.tensor(row['label'], dtype=torch.float32)
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# ===============================
# 4. Transforms and data loaders
# ===============================
# Per-channel statistics used by the preprocessing of the pretrained ViT
# weights loaded in section 6. The backbone is frozen, so inputs must be
# normalized the way the backbone expects; the previous 0.5/0.5 values
# fed it a shifted input distribution.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 16

train_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    # ColorJitter() with its default arguments (all zero) leaves the image
    # unchanged, so explicit, mild jitter ranges are given here.
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Validation and test data get no augmentation, only resize + normalization.
val_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(SmokerDataset(train_df, train_transform), batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(SmokerDataset(val_df, val_transform), batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(SmokerDataset(test_df, val_transform), batch_size=BATCH_SIZE, shuffle=False)

# ===============================
# 5. Class distribution check
# ===============================
def check_class_distribution(df, name):
    """Print how many samples each label value has in ``df['label']``.

    Label 1 is reported as "Smoker", every other label as "Non-Smoker".
    A warning line is printed when fewer than two distinct labels occur.

    Args:
        df: Frame with a ``label`` column.
        name: Split name shown in the heading line.
    """
    per_label = df['label'].value_counts().sort_index()

    print(f"\n📊 {name} class distribution:")
    for label_value, n_samples in per_label.items():
        if label_value == 1:
            cls = "Smoker"
        else:
            cls = "Non-Smoker"
        print(f"  {cls}: {n_samples}")

    has_both_classes = len(per_label) >= 2
    if not has_both_classes:
        print("⚠️ WARNING: Only one class detected!")
    print("-" * 30)

# Report the label balance of every split, in train / validation / test order.
for split_df, split_name in (
    (train_df, "Train"),
    (val_df, "Validation"),
    (test_df, "Test"),
):
    check_class_distribution(split_df, split_name)

# ===============================
# 6. Model: Vision Transformer (ViT)
# ===============================

from torchvision.models import vit_b_16, ViT_B_16_Weights

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Start from the SWAG linear pretrained weights.
weights = ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1
model = vit_b_16(weights=weights)

# Swap the final layer for a single-logit head (binary classification).
head_in_features = model.heads.head.in_features
model.heads.head = nn.Linear(head_in_features, 1)

# Freeze every parameter, then re-enable gradients for the new head only.
model.requires_grad_(False)
model.heads.head.requires_grad_(True)

model = model.to(device)

# Loss and optimizer; only the head's parameters are handed to Adam.
criterion = nn.BCEWithLogitsLoss()
trainable_params = model.heads.head.parameters()
optimizer = torch.optim.Adam(trainable_params, lr=1e-4)

# ===============================
# 7. Training
# ===============================
def _run_epoch(model, loader, training, desc=None):
    """Run one pass over *loader* and return ``(mean_loss, accuracy)``.

    Uses the module-level ``criterion``, ``optimizer`` and ``device``.
    When *training* is true the model is put in train mode and the optimizer
    is stepped after every batch; otherwise the model is put in eval mode and
    gradients are disabled. Both metrics are averaged over the samples
    actually seen, so a smaller final batch is weighted correctly.

    Raises:
        ValueError: If the loader yields no samples.
    """
    model.train(training)
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    batches = loader if desc is None else tqdm(loader, desc=desc)

    with torch.set_grad_enabled(training):
        for images, labels in batches:
            images, labels = images.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(images).squeeze(1)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            # criterion returns the batch mean (default reduction), so scale
            # it back to a per-sample sum before accumulating.
            loss_sum += loss.item() * batch_size
            preds = torch.sigmoid(outputs) > 0.5
            n_correct += (preds == labels.bool()).sum().item()
            n_seen += batch_size

    if n_seen == 0:
        raise ValueError("Data loader yielded no samples.")
    return loss_sum / n_seen, n_correct / n_seen


def _plot_history(history):
    """Plot per-epoch loss and accuracy curves side by side."""
    plt.figure(figsize=(12,5))
    plt.subplot(1,2,1)
    plt.plot(history["train_loss"], label='Train Loss')
    plt.plot(history["val_loss"], label='Val Loss')
    plt.title("Loss per Epoch")
    plt.legend()

    plt.subplot(1,2,2)
    plt.plot(history["train_acc"], label='Train Acc')
    plt.plot(history["val_acc"], label='Val Acc')
    plt.title("Accuracy per Epoch")
    plt.legend()
    plt.show()


def train_model(model, train_loader, val_loader, epochs=10):
    """Train *model*, keep the weights with the best validation accuracy.

    Relies on the module-level ``criterion``, ``optimizer`` and ``device``.
    After each epoch the validation accuracy is compared with the best seen
    so far; on improvement the weights are copied in memory and written to
    ``best_model.pt``. When training ends the best weights are loaded back
    into *model* and the loss/accuracy curves are plotted.

    Args:
        model: Network producing one logit per sample.
        train_loader: Loader yielding ``(images, labels)`` training batches.
        val_loader: Loader yielding ``(images, labels)`` validation batches.
        epochs: Number of passes over the training data.

    Returns:
        A dict with the per-epoch lists ``train_loss``, ``val_loss``,
        ``train_acc`` and ``val_acc``.
    """
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; with 0.0 a run that never beats 0 accuracy would leave
    # "best_model.pt" missing or stale from an earlier run.
    best_acc = float("-inf")
    best_model_wts = copy.deepcopy(model.state_dict())
    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}

    for epoch in range(epochs):
        train_loss, train_acc = _run_epoch(
            model, train_loader, training=True, desc=f"Epoch {epoch+1}/{epochs}"
        )
        val_loss, val_acc = _run_epoch(model, val_loader, training=False)

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), "best_model.pt")
            print("✅ Model saved.\n")

    model.load_state_dict(best_model_wts)
    _plot_history(history)
    return history

# ===============================
# 8. Run training
# ===============================
train_model(model, train_loader, val_loader, epochs=15)

# Reload the best checkpoint written during training. map_location remaps
# the stored tensors onto the current device, so a checkpoint saved on a GPU
# runtime can still be loaded after switching to a CPU-only runtime.
model.load_state_dict(torch.load("best_model.pt", map_location=device))

# ===============================
# 9. Testing + ROC
# ===============================
model.eval()
all_preds = []   # hard 0/1 predictions at the 0.5 threshold
all_probs = []   # sigmoid probabilities, used for the ROC curve
all_labels = []  # ground-truth 0/1 labels

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        outputs = model(images).squeeze(1)
        probs = torch.sigmoid(outputs)
        preds = probs > 0.5

        all_probs.extend(probs.cpu().numpy())
        all_preds.extend(preds.cpu().numpy().astype(int))
        all_labels.extend(labels.numpy().astype(int))

print("\n🧪 Test Set Evaluation:")
# Label 0 is non-smoker and label 1 is smoker (see get_dataframe_from_folder).
# target_names must follow that order; the names were previously swapped,
# which attributed every metric to the wrong class. Passing labels pins the
# order and keeps the report valid when a class is absent from the test set.
print(classification_report(
    all_labels,
    all_preds,
    labels=[0, 1],
    target_names=["Non-Smoker", "Smoker"],
))

# Confusion matrix
# Pin the label order so the matrix is always 2x2 and its rows/columns line
# up with the tick labels, even if one class never occurs in the test labels
# or predictions.
class_names = ["Non-Smoker", "Smoker"]
cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix")
plt.show()

# ROC curve
fpr, tpr, _ = roc_curve(all_labels, all_probs)
roc_auc = auc(fpr, tpr)

fig, ax = plt.subplots(figsize=(6, 6))
ax.plot(fpr, tpr, label=f"ROC Curve (AUC = {roc_auc:.2f})")
# Dashed diagonal from (0, 0) to (1, 1) as a visual reference line.
ax.plot([0, 1], [0, 1], 'k--')
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.set_title("ROC Curve")
ax.legend()
ax.grid()
plt.show()
