In [None]:
import torch
import torchvision
import torch.nn as nn
import numpy as np
import random
import torchvision.transforms as transforms
from tqdm import tqdm
from matplotlib import pyplot as plt
import seaborn as sns
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report

import os
from torch.utils.data import Dataset, DataLoader
from PIL import Image

In [None]:
# Training function.
def train(epoch, model, loader, criterion, optimizer, device='cpu'):
    """Run one training epoch over ``loader`` and return the summed loss.

    Args:
        epoch: zero-based epoch index; only used for the progress-bar label.
        model: the module being optimised. It is put in training mode first so
            layers such as dropout / batch-norm use their training behaviour.
        loader: iterable yielding batches whose element 0 is the input tensor
            and element 1 is the label tensor.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimiser updating ``model``'s parameters.
        device: device each batch is moved to.

    Returns:
        Sum of the per-batch losses as a detached tensor (``0`` if the loader
        is empty). Divide by the number of batches for a mean loss.
    """
    model.train()
    total_loss = 0
    for data in tqdm(loader, desc=f'Epoch {epoch+1:03d}'):
        x = data[0].to(device)
        y = data[1]
        # Labels may arrive as (B, 1); drop only that trailing dim. A plain
        # .squeeze() would also turn a 1-sample batch into a 0-d tensor,
        # which the loss rejects.
        if y.dim() > 1:
            y = y.squeeze(-1)
        y = y.to(device)
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Accumulate a detached value: adding the live `loss` would keep every
        # batch's autograd graph in memory until the epoch ends.
        total_loss += loss.detach()
    return total_loss

# Test function.
def test(model, loader, criterion, device='cpu'):
    l = 0
    correct = 0
    total = 0
    y_true = []
    y_pred = []
    with torch.no_grad():
        for data in loader:
            x = data[0].to(device)
            y = data[1].squeeze().to(device)
            out = model(x)
            l += criterion(out, y)
            _, pred = torch.max(out.data, 1)
            total += y.size(0)
            correct += (pred == y).sum().item()
            y_true += y.tolist()
            y_pred += pred.tolist()
    return l, correct / total, y_true, y_pred

In [None]:
class ImmaginiDataset(Dataset):
    """Image-classification dataset read from a flat directory of files.

    Every file in ``root_dir`` whose name ends with one of ``extensions``
    (compared case-insensitively) is one sample. Its integer label is the part
    of the file name before the first underscore, e.g. ``3_img001.png`` has
    label ``3``. Samples are ordered by file name so indices are reproducible.

    Args:
        root_dir: directory containing the image files (not searched
            recursively).
        transform: optional callable applied to each loaded PIL image.
        extensions: accepted file-name suffixes; defaults to ``.jpg``/``.png``.
        image_mode: optional PIL mode (e.g. ``"RGB"``) every image is
            converted to. ``None`` (default) keeps each file's own mode; pass
            ``"RGB"`` if the folder mixes grayscale/RGBA/palette images, which
            would otherwise produce tensors with different channel counts.

    Raises:
        ValueError: if an image file name does not start with an integer
            label followed by ``_``.
    """

    def __init__(self, root_dir, transform=None, extensions=(".jpg", ".png"), image_mode=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_mode = image_mode
        self.images = []
        self.labels = []

        suffixes = tuple(ext.lower() for ext in extensions)
        # Load image paths and labels. os.listdir returns entries in an
        # arbitrary, filesystem-dependent order, so sort for reproducibility.
        for file in sorted(os.listdir(root_dir)):
            if not file.lower().endswith(suffixes):
                continue
            prefix = file.split("_")[0]
            try:
                label = int(prefix)  # label is encoded in the file name
            except ValueError as err:
                raise ValueError(
                    f"Cannot parse an integer label from file name {file!r} "
                    f"in {root_dir!r}; expected '<label>_<name>'."
                ) from err
            self.images.append(os.path.join(root_dir, file))
            self.labels.append(label)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        # Image.open is lazy and holds the file handle open until the pixels
        # are read; materialise a copy inside a context manager so the handle
        # is always released (matters with many samples / loader workers).
        with Image.open(img_path) as img:
            if self.image_mode is None:
                image = img.copy()
            else:
                image = img.convert(self.image_mode)
        if self.transform:
            image = self.transform(image)
        return image, label

# Settings. The image root can be set with the IMAGES_ROOT environment
# variable; it must contain `train/` and `test/` sub-directories holding files
# named `<label>_<anything>.jpg|png`.
root_dir = os.environ.get("IMAGES_ROOT", "path/to/your/images")
train_dir = os.path.join(root_dir, "train")
test_dir = os.path.join(root_dir, "test")

# Training transformations: resize to the 224x224 input size, then light
# augmentation (padded random crop, horizontal flip, random rotation within
# +/-10 degrees) and conversion to a tensor.
train_transformations = transforms.Compose([
    transforms.Resize((224, 224)),  # fit images to the required input size
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor()
])

# Test transformations: deterministic resize only, no augmentation.
test_transformations = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

if os.path.isdir(train_dir) and os.path.isdir(test_dir):
    # Datasets.
    train_dataset = ImmaginiDataset(train_dir, transform=train_transformations)
    test_dataset = ImmaginiDataset(test_dir, transform=test_transformations)

    # Data loaders: shuffle only the training split.
    train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
else:
    # The default root is a placeholder; without this guard os.listdir raises
    # FileNotFoundError and "Restart & Run All" stops at this cell.
    print(f"Image directories under {root_dir!r} not found (set IMAGES_ROOT); "
          "skipping the custom-image datasets.")


In [None]:
# DermaMNIST pipeline. It has no Resize step; presumably size=224 already
# yields 224x224 images (TODO confirm against the medmnist docs).
# NOTE(review): this cell overwrites the datasets/loaders from the previous
# cell, and `DermaMNIST` (from the `medmnist` package) is not imported in the
# imports cell. Add `from medmnist import DermaMNIST` there, otherwise this
# cell raises NameError on a fresh kernel.
augmentation_steps = [
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
]
train_transformations = transforms.Compose(augmentation_steps)
del augmentation_steps

# Datasets: augmented training split, plain tensor conversion for testing.
train_dataset = DermaMNIST(split='train', size=224, download=True,
                           transform=train_transformations)
test_dataset = DermaMNIST(split='test', size=224, download=True,
                          transform=transforms.ToTensor())

# Loaders: batches of 64, shuffling only the training split.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)