In [6]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import timm
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === Загрузка CutMix датасета ===
data = torch.load("../data/cifar10_cutmix.pt")
images, labels = data

train_loader = DataLoader(
    TensorDataset(images, labels),
    batch_size=64,
    shuffle=True
)


In [7]:
# === Модель DeiT-S ===
model = timm.create_model("deit_small_patch16_224", pretrained=True, num_classes=10)
model.to(device)

from experiments.EMA_for_weights import EMA

ema = EMA(model, decay=0.999)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

# === Тренировка ===
epochs = 5
model.train()
for epoch in range(epochs):
    total_loss = 0
    total_correct = 0
    total_samples = 0

    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

    for x, y in loop:
        x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
        x, y = x.to(device), y.to(device)

        if y.ndim > 1:
            y = torch.argmax(y, dim=1)
        y = y.long()

        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        ema.update(model)

        total_loss += loss.item()

        preds = out.argmax(dim=1)
        correct = (preds == y).sum().item()
        total_correct += correct
        total_samples += y.size(0)

        accuracy = 100 * total_correct / total_samples
        loop.set_postfix(loss=loss.item(), accuracy=f"{accuracy:.2f}%")

    avg_loss = total_loss / len(train_loader)
    avg_accuracy = 100 * total_correct / total_samples
    print(f"[Epoch {epoch+1}] Avg Loss: {avg_loss:.4f}, Avg Accuracy: {avg_accuracy:.2f}%")


Epoch 1/5: 100%|██████████| 782/782 [02:44<00:00,  4.76it/s, accuracy=81.05%, loss=0.266]


[Epoch 1] Avg Loss: 0.5741, Avg Accuracy: 81.05%


Epoch 2/5: 100%|██████████| 782/782 [02:39<00:00,  4.90it/s, accuracy=91.90%, loss=0.0616]


[Epoch 2] Avg Loss: 0.2440, Avg Accuracy: 91.90%


Epoch 3/5: 100%|██████████| 782/782 [02:39<00:00,  4.89it/s, accuracy=95.21%, loss=0.0885]


[Epoch 3] Avg Loss: 0.1419, Avg Accuracy: 95.21%


Epoch 4/5: 100%|██████████| 782/782 [02:40<00:00,  4.88it/s, accuracy=96.63%, loss=0.37]   


[Epoch 4] Avg Loss: 0.1004, Avg Accuracy: 96.63%


Epoch 5/5: 100%|██████████| 782/782 [02:40<00:00,  4.88it/s, accuracy=97.29%, loss=0.0323] 

[Epoch 5] Avg Loss: 0.0806, Avg Accuracy: 97.29%





In [8]:
# === Сохраняем модель ===
torch.save(model.state_dict(), "../data/model_weights/deit_s_cifar10_aug.pt")
torch.save(ema.state_dict(), "../data/model_weights/deit_s_cifar10_aug_ema.pt")
