In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm
from tqdm import tqdm

BATCH_SIZE = 64
EPOCHS = 10
LR = 3e-4
IMAGE_SIZE = 224
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DEVICE

device(type='cuda')

In [3]:
images, labels = torch.load('data/cifar10_cutmix.pt')  # [N, 3, 32, 32], [N, 10]

# images = images[:5000]
# labels = labels[:5000]

if labels.ndim == 2 and labels.shape[1] > 1:
    labels = labels.argmax(dim=1)

class CutMixCIFAR10Dataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx].float() / 255.0  # из [0..255] в [0..1]
        label = self.labels[idx]
        if self.transform:
            img = self.transform(img)
        return img, label

transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


In [4]:
dataset = CutMixCIFAR10Dataset(images, labels, transform=transform)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

model = timm.create_model('deit_tiny_patch16_224', pretrained=True, num_classes=10)
model.to(DEVICE)

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)


In [5]:

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR)

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    for imgs, lbls in tqdm(loader):
        imgs = imgs.to(DEVICE)
        lbls = lbls.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, lbls)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * imgs.size(0)

    epoch_loss = running_loss / len(dataset)
    print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {epoch_loss:.4f}")

100%|██████████| 782/782 [01:03<00:00, 12.31it/s]


Epoch 1/10 - Loss: 2.3091


100%|██████████| 782/782 [01:02<00:00, 12.42it/s]


Epoch 2/10 - Loss: 2.3046


100%|██████████| 782/782 [01:04<00:00, 12.11it/s]


Epoch 3/10 - Loss: 2.3037


100%|██████████| 782/782 [01:04<00:00, 12.15it/s]


Epoch 4/10 - Loss: 2.3037


100%|██████████| 782/782 [01:03<00:00, 12.28it/s]


Epoch 5/10 - Loss: 2.3034


100%|██████████| 782/782 [01:03<00:00, 12.28it/s]


Epoch 6/10 - Loss: 2.3034


100%|██████████| 782/782 [01:03<00:00, 12.25it/s]


Epoch 7/10 - Loss: 2.3031


100%|██████████| 782/782 [01:04<00:00, 12.19it/s]


Epoch 8/10 - Loss: 2.3032


100%|██████████| 782/782 [01:04<00:00, 12.12it/s]


Epoch 9/10 - Loss: 2.3031


100%|██████████| 782/782 [01:03<00:00, 12.25it/s]

Epoch 10/10 - Loss: 2.3032



