In [None]:
from src.dataset_loader import get_dataloader
from src.train import train_model
from torchvision import models
from torchvision.models import resnet18, ResNet18_Weights
import torch.nn as nn
import torch.optim as optim
import torch
import os

# Setup: pick the compute device and build the training DataLoader.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_path = "datasets/dataset_prepared/train"
# Single source of truth for the batch size, so the helper call and the
# DataLoader built below cannot drift apart.
batch_size = 32
# Only the first element of the returned pair is used in this cell.
# NOTE(review): it is re-wrapped in a DataLoader below, so it is presumably a
# Dataset rather than an already-batched loader — confirm against
# src.dataset_loader.get_dataloader.
dataset, _ = get_dataloader(data_dir=train_path, batch_size=batch_size)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    # Pinned (page-locked) host memory only speeds up host-to-GPU copies;
    # requesting it on a CPU-only machine just makes PyTorch emit a warning.
    pin_memory=device.type == "cuda",
)

# Model: ResNet-18 initialised with torchvision's default weights, with the
# final fully connected layer swapped for a fresh 2-output linear head.
backbone = resnet18(weights=ResNet18_Weights.DEFAULT)
head_in_features = backbone.fc.in_features
backbone.fc = nn.Linear(head_in_features, 2)
# Move the parameters (including the new head) to the selected device.
model = backbone.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# Training
# Keyword options forwarded to train_model.
# NOTE(review): use_wandb presumably enables Weights & Biases logging —
# confirm against src.train.train_model.
train_options = {"epochs": 10, "use_wandb": True}
history = train_model(
    model,
    dataloader,
    criterion,
    optimizer,
    device,
    **train_options,
)

# (Optional) Show the metrics directly in the notebook
import matplotlib.pyplot as plt

# NOTE(review): history["f1"] is presumably one F1 value per epoch — confirm
# against what train_model returns.
f1_values = history["f1"]
fig, ax = plt.subplots()
ax.plot(f1_values)
ax.set_title("F1-score")
ax.grid()
plt.show()
