In [2]:
from torchvision.datasets import CIFAR10
from ffcv.writer import DatasetWriter
from ffcv.fields import RGBImageField, IntField
from torch.utils.data import Subset
import os
from ffcv.loader import Loader, OrderOption
from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, Convert
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
import torch
import torchvision.transforms as T


write_path = '/tmp/cifar10_train.beton'


cifar_train = CIFAR10(root='./data', train=True, download=True)


dataset = Subset(cifar_train, range(len(cifar_train)))


writer = DatasetWriter(write_path, {
    'image': RGBImageField(max_resolution=32),
    'label': IntField()
})


writer.from_indexed_dataset(dataset)

  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 50000/50000 [00:00<00:00, 98201.32it/s] 


In [4]:
import timm
import torch.nn as nn


model = timm.create_model('deit_small_patch16_224', pretrained=True)
model.head = nn.Linear(model.head.in_features, 10)
model = model.to(device)

In [5]:
from ffcv.transforms import RandomResizedCrop, RandomHorizontalFlip
from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, Convert
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.loader import Loader, OrderOption
from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, Cutout
from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder
from torch import float32


image_pipeline = [
    RandomResizedCropRGBImageDecoder((224, 224)),
    Cutout(16),
    ToTensor(),
    ToTorchImage(),
    Convert(float32),
    ToDevice(device)
]

label_pipeline = [
    IntDecoder(),
    ToTensor(),
    ToDevice(device)
]

train_loader = Loader(
    './data/cifar10_train.beton',
    batch_size=128,
    num_workers=4,
    order=OrderOption.RANDOM,
    drop_last=True,
    pipelines={
        'image': image_pipeline,
        'label': label_pipeline
    }
)


In [6]:
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.05)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        optimizer.zero_grad()
        labels = labels.squeeze(1)  # Убираем размерность 1, превращаем в [128]
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, pred = output.max(1)
        correct += pred.eq(labels).sum().item()
        total += labels.size(0)

    acc = 100. * correct / total
    print(f"Train loss: {total_loss:.3f}, Accuracy: {acc:.2f}%")


Epoch 1: 100%|██████████| 390/390 [02:22<00:00,  2.73it/s]


Train loss: 780.921, Accuracy: 24.96%


Epoch 2: 100%|██████████| 390/390 [02:18<00:00,  2.81it/s]


Train loss: 653.135, Accuracy: 38.42%


Epoch 3: 100%|██████████| 390/390 [02:18<00:00,  2.82it/s]


Train loss: 605.927, Accuracy: 43.68%


Epoch 4: 100%|██████████| 390/390 [02:19<00:00,  2.80it/s]


Train loss: 574.814, Accuracy: 46.93%


Epoch 5: 100%|██████████| 390/390 [02:20<00:00,  2.77it/s]


Train loss: 557.909, Accuracy: 48.46%


Epoch 6: 100%|██████████| 390/390 [02:18<00:00,  2.81it/s]


Train loss: 539.914, Accuracy: 50.33%


Epoch 7: 100%|██████████| 390/390 [02:19<00:00,  2.79it/s]


Train loss: 527.382, Accuracy: 51.48%


Epoch 8: 100%|██████████| 390/390 [02:19<00:00,  2.80it/s]


Train loss: 517.538, Accuracy: 52.57%


Epoch 9: 100%|██████████| 390/390 [02:19<00:00,  2.80it/s]


Train loss: 505.377, Accuracy: 53.25%


Epoch 10: 100%|██████████| 390/390 [02:18<00:00,  2.81it/s]

Train loss: 498.128, Accuracy: 54.14%





In [32]:
import os
save_dir = '../data/model_weights'
os.makedirs(save_dir, exist_ok=True)

In [33]:
model_path = os.path.join(save_dir, 'deit_small_cifar10.pth')
try:
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'accuracy': acc,
    }, model_path)
    print(f"Модель сохранена в {model_path}")
except Exception as e:
    print(f"Ошибка при сохранении модели: {e}")


Модель сохранена в data/model_weights/deit_small_cifar10.pth
