In [11]:
from torchvision.datasets import CIFAR10
from ffcv.writer import DatasetWriter
from ffcv.fields import RGBImageField, IntField
from torch.utils.data import Subset
import numpy as np
import os

class CIFAR10Adapter:
    """Index-based view over a torchvision-style dataset that yields numpy images.

    Every item of the wrapped dataset is unpacked as an ``(image, label)`` pair.
    The image is converted with ``np.array`` (for the RGB PIL images produced by
    torchvision's CIFAR10 this is an (H, W, 3) array); the label is returned as is.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        image, target = sample
        return np.array(image), target

# Destination of the FFCV .beton file.
write_path = './data/cifar10_train.beton'

# CIFAR-10 training split (downloaded into ./data on first run), wrapped so
# that every sample is an (image ndarray, label) tuple.
cifar_train = CIFAR10(root='./data', train=True, download=True)
dataset = CIFAR10Adapter(cifar_train)

# FFCV's field mapping is ordered: the keys must follow the order of the values
# in each sample tuple returned by `dataset`, i.e. (image, label).
writer = DatasetWriter(
    write_path,
    {
        'image': RGBImageField(write_mode='smart', max_resolution=32),
        'label': IntField(),
    },
)

# Serialize every sample of `dataset` into the .beton file.
writer.from_indexed_dataset(dataset)


100%|██████████| 50000/50000 [00:00<00:00, 83099.26it/s] 


In [8]:
import timm
import torch
import torch.nn as nn

# Seed PyTorch's RNGs so the freshly initialised classification head (and any
# other draw from torch's default generator) is reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

NUM_CLASSES = 10  # CIFAR-10

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pretrained DeiT-Small (16x16 patches, 224x224 input); its pretrained
# classification head is replaced by a new linear layer with one output per
# CIFAR-10 class.
model = timm.create_model('deit_small_patch16_224', pretrained=True)
model.head = nn.Linear(model.head.in_features, NUM_CLASSES)
model = model.to(device)

In [9]:
# Display the device chosen above (CUDA if available, otherwise CPU) that the model was moved to.
device

device(type='cuda')

In [12]:
from ffcv.reader import Reader

# Re-open the written .beton file and print its metadata table
# (one structured record per stored sample).
reader = Reader(
    './data/cifar10_train.beton',
)
print(reader.metadata)

[((1, 32, 32,   8388608), 6) ((1, 32, 32,   8391680), 9)
 ((1, 32, 32,   8394752), 9) ... ((1, 32, 32, 108948480), 9)
 ((1, 32, 32, 108951552), 1) ((1, 32, 32, 108954624), 1)]


In [5]:
from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder, SimpleRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.transforms import (
    Convert,
    Cutout,
    NormalizeImage,
    RandomHorizontalFlip,
    RandomResizedCrop,
    ToDevice,
    ToTensor,
    ToTorchImage,
)
from torch import float32
from torchvision.transforms import Normalize
import numpy as np

# FFCV decodes images as uint8 pixels in [0, 255], while the CIFAR-10 channel
# statistics below are the usual [0, 1]-scale values, so rescale them to the
# pixel range. (NormalizeImage builds its lookup table over raw pixel values
# 0..255 as well, so passing it [0, 1]-scale stats produced wildly scaled inputs.)
CIFAR_MEAN = [255 * m for m in (0.4914, 0.4822, 0.4465)]
CIFAR_STD = [255 * s for s in (0.2023, 0.1994, 0.2010)]

image_pipeline = [
    SimpleRGBImageDecoder(),           # stored 32x32 RGB images as uint8
    ToTensor(),
    ToTorchImage(),                    # batch -> PyTorch (B, C, H, W) layout
    Convert(float32),                  # uint8 -> float32 before normalizing
    Normalize(CIFAR_MEAN, CIFAR_STD),  # torchvision op on the batched tensor
]

label_pipeline = [
    IntDecoder(),
    ToTensor(),
]

In [15]:
import os
import shutil

# Remove the ffcv loader cache directory of the *current* user instead of a
# hard-coded /home/<user>/ path, so this cell works on any machine.
# Best-effort: a missing directory or failed removal is ignored.
FFCV_LOADER_CACHE = os.path.join(os.path.expanduser('~'), '.cache', 'ffcv', 'loader_cache')
shutil.rmtree(FFCV_LOADER_CACHE, ignore_errors=True)

In [6]:
# FFCV loader over the training .beton: samples reshuffled each epoch, the
# trailing partial batch dropped, and file reads served through the OS cache.
# NOTE(review): unlike torch's DataLoader, FFCV appears to treat num_workers < 1
# as "use all CPU cores" rather than "main process only" — confirm.
train_loader = Loader(
    './data/cifar10_train.beton',
    pipelines={'image': image_pipeline, 'label': label_pipeline},
    batch_size=32,
    drop_last=True,
    order=OrderOption.RANDOM,
    num_workers=0,
    os_cache=True,
    recompile=False,
)


KeyboardInterrupt: 

In [13]:
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.05)
criterion = nn.CrossEntropyLoss()

NUM_EPOCHS = 10
# deit_small_patch16_224 expects 224x224 inputs (timm's patch embedding rejects
# other sizes by default), but the .beton stores 32x32 CIFAR-10 images.
IMG_SIZE = 224

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0  # sum of per-sample losses over the epoch
    correct = 0
    total = 0

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        # The loader pipelines contain no ToDevice step, so batches arrive on the
        # CPU while the model lives on `device`; `.to` is a no-op if they match.
        images = images.to(device, dtype=torch.float32, non_blocking=True)
        # Flatten (B, 1) or (B,) label batches to (B,) int64, as CrossEntropyLoss needs.
        labels = labels.to(device, non_blocking=True).reshape(-1).long()
        if images.shape[-2:] != (IMG_SIZE, IMG_SIZE):
            images = F.interpolate(images, size=(IMG_SIZE, IMG_SIZE),
                                   mode='bilinear', align_corners=False)

        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # criterion averages over the batch; weight by batch size to get a sum.
        total_loss += loss.item() * batch_size
        correct += output.argmax(dim=1).eq(labels).sum().item()
        total += batch_size

    if total == 0:
        raise RuntimeError("train_loader yielded no batches; nothing was trained")
    acc = 100. * correct / total
    print(f"Train loss: {total_loss / total:.3f}, Accuracy: {acc:.2f}%")


Epoch 1:   0%|          | 0/390 [00:00<?, ?it/s]Exception ignored in: <finalize object at 0x733165fd5ee0; dead>
Traceback (most recent call last):
  File "/usr/lib/python3.12/weakref.py", line 590, in __call__
    return info.func(*info.args, **(info.kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amir_ubuntu/.virtualenvs/transformer-model-optimization/lib/python3.12/site-packages/numba/core/dispatcher.py", line 268, in finalizer
    for cres in overloads.values():
KeyError: (Array(uint8, 1, 'C', True, aligned=True), Array(uint8, 1, 'C', True, aligned=True), uint32, uint32, uint32, uint32, Literal[int](0), Literal[int](0), Literal[int](1), Literal[int](1), Literal[bool](False), Literal[bool](False))
Exception ignored in: <finalize object at 0x733165c39080; dead>
Traceback (most recent call last):
  File "/usr/lib/python3.12/weakref.py", line 590, in __call__
    return info.func(*info.args, **(info.kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^

KeyboardInterrupt: 

In [32]:
import os
# Directory for model checkpoints (relative to the notebook's working directory).
# NOTE(review): the dataset lives in ./data and the recorded save output shows
# data/model_weights, not ../data/model_weights — confirm the intended location.
save_dir = '../data/model_weights'
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)

In [33]:
# Checkpoint the model and optimizer state together with the epoch count and the
# last epoch's training accuracy. Any failure (including names left undefined
# because training did not run) is reported rather than raised.
model_path = os.path.join(save_dir, 'deit_small_cifar10.pth')
try:
    torch.save(
        dict(
            epoch=epoch + 1,
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
            accuracy=acc,
        ),
        model_path,
    )
    print(f"Модель сохранена в {model_path}")  # "Model saved to ..."
except Exception as err:
    print(f"Ошибка при сохранении модели: {err}")  # "Error while saving the model: ..."


Модель сохранена в data/model_weights/deit_small_cifar10.pth
