In [None]:
import timm
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm

# === 1. Augmentations ===
def _build_transforms(augment):
    """Assemble the image preprocessing pipeline.

    Every pipeline resizes to 224x224, converts to a tensor and normalizes
    each of the three channels with mean 0.5 and std 0.5. When *augment* is
    true, a random horizontal flip and a random rotation of up to 10 degrees
    are inserted between the resize and the tensor conversion.
    """
    steps = [transforms.Resize((224, 224))]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
        steps.append(transforms.RandomRotation(10))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]))
    return transforms.Compose(steps)


# Training pipeline includes the random augmentations.
train_transforms = _build_transforms(augment=True)

# Validation/test pipeline is the same preprocessing without augmentations.
val_test_transforms = _build_transforms(augment=False)

# === 2. Load the dataset ===
# Two views over the same image folder: `full_dataset` applies the training
# augmentations, `eval_dataset` applies the augmentation-free preprocessing.
# Separate dataset objects are required because the subsets produced by
# random_split all reference the dataset they were split from; overwriting
# `.dataset.transform` on one subset would change the transform for every
# subset, including the training one.
full_dataset = datasets.ImageFolder("dataset/", transform=train_transforms)
eval_dataset = datasets.ImageFolder("dataset/", transform=val_test_transforms)

# === 3. Split into train/test/val ===
total_size = len(full_dataset)
train_size = int(0.7 * total_size)
test_size = int(0.2 * total_size)
val_size = total_size - train_size - test_size  # whatever remains (~10%)

# A seeded generator keeps the split identical across runs, so the test set
# is never mixed with samples used for training in an earlier run.
train_dataset, test_split, val_split = random_split(
    full_dataset,
    [train_size, test_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

# Reuse the test/val indices on the augmentation-free dataset so that only
# the training subset sees random flips and rotations.
test_dataset = torch.utils.data.Subset(eval_dataset, test_split.indices)
val_dataset = torch.utils.data.Subset(eval_dataset, val_split.indices)

# === 4. DataLoaders ===
# All three loaders use the same batch size; only the training loader
# shuffles its samples.
BATCH_SIZE = 32

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# === 5. Model ===
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Swin model created through timm with pretrained=True and a 2-class output,
# moved to the selected device.
model = timm.create_model(
    "swin_base_patch4_window7_224",
    pretrained=True,
    num_classes=2,
).to(device)

# === 6. Optimizer and loss ===
optimizer = optim.AdamW(params=model.parameters(), lr=1e-4, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

# === 7. Training with a progress bar ===
num_epochs = 5

for epoch in range(num_epochs):
    # --- Training pass ---
    model.train()
    running_loss = 0.0  # sum of the per-batch loss values seen this epoch
    correct, total = 0, 0

    train_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [Train]", unit="batch")
    # batch_idx counts the batches processed so far (1-based) so the running
    # average below is taken over what has actually been seen.
    for batch_idx, (images, labels) in enumerate(train_bar, start=1):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        preds = torch.argmax(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        # Divide by the number of batches processed, not len(train_loader):
        # dividing by the full epoch length understates the loss until the
        # final batch.
        train_bar.set_postfix(loss=running_loss / batch_idx, acc=correct / total)

    # === Validation ===
    model.eval()
    correct, total = 0, 0  # counters restart for the validation pass
    val_bar = tqdm(val_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [Val]", unit="batch")
    with torch.no_grad():  # no gradients are needed for evaluation
        for images, labels in val_bar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            preds = torch.argmax(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            val_bar.set_postfix(val_acc=correct / total)

# === 8. Testing ===
# Final pass over the test loader: accumulate the number of correct
# predictions and show the running accuracy on the progress bar.
model.eval()
correct = 0
total = 0
with torch.no_grad():  # evaluation only, no gradient tracking
    test_bar = tqdm(test_loader, desc="Testing", unit="batch")
    for images, labels in test_bar:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        preds = outputs.argmax(dim=1)  # index of the highest-scoring class

        correct += int((preds == labels).sum())
        total += labels.shape[0]
        test_bar.set_postfix(test_acc=correct / total)