In [1]:
from google.colab import drive
# Mount Google Drive at /content/drive.
drive.mount('/content/drive')

import os
# Folder on the mounted Drive; train_model writes its checkpoint files into it.
save_dir = '/content/drive/MyDrive/Colab Notebooks/ViT'
os.makedirs(save_dir, exist_ok=True)  # no error if the folder already exists

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# importing the zipfile module
from zipfile import ZipFile

# Open the dataset archive read-only and unpack every member into /content.
with ZipFile("/content/mushrooms_small.zip", mode='r') as archive:
    archive.extractall("/content")

In [1]:
from transformers import ViTConfig, ViTModel

# Build a ViT with random (untrained) weights from the library's default ViTConfig.
model = ViTModel(ViTConfig())

# Keep a handle on the configuration object carried by the model.
configuration = model.config

In [2]:
from transformers import ViTForImageClassification, ViTConfig
import torch.nn as nn

# Load the pretrained ViT-Base checkpoint (the "in21k" variant named below)
model_name = 'google/vit-base-patch16-224-in21k'
config = ViTConfig.from_pretrained(model_name)

# Set the number of output classes for this dataset
config.num_labels = 4  # hard-coded; replace with the number of classes in your dataset

# Load the model with a new classification head
model = ViTForImageClassification.from_pretrained(
    model_name,
    config=config,
    ignore_mismatched_sizes=True  # tolerate the size mismatch of the output layer
)

# Inspect the architecture
print(model)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermed

# Подготовка датасета и тренировки

In [5]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from torchvision import transforms


class MushroomDataset(Dataset):
    """Image dataset laid out as ``root_dir/<edibility>/<species>/<image file>``.

    Every image gets two integer labels: the index of its edibility folder
    (``ed_id``) and the index of its species folder (``mush_id``). Both index
    spaces are built from *sorted* folder names, so the name -> index mapping
    is the same on every run and every machine. ``__getitem__`` returns only
    the edibility index as the target.
    """

    # File extensions (compared case-insensitively) that are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    def __init__(self, root_dir, transform=None, target_size=None):
        """
        Args:
            root_dir (str): Folder that contains one sub-folder per edibility class.
            transform: Optional callable applied to the PIL image after resizing.
            target_size (tuple): Optional ``(width, height)`` the image is resized to.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.target_size = target_size

        # Edibility classes = sorted sub-folders of the root.
        self.ediable_cls = self._list_subdirs(root_dir)
        self.ediable2idx = {cls_name: idx for idx, cls_name in enumerate(self.ediable_cls)}

        # Species classes = sub-folders of every edibility folder. Stray files
        # are ignored, and the names are de-duplicated and sorted so that each
        # species maps to exactly one stable index.
        species_by_group = {
            ed_name: self._list_subdirs(os.path.join(root_dir, ed_name))
            for ed_name in self.ediable_cls
        }
        self.mushroom_cls = sorted(
            {name for names in species_by_group.values() for name in names}
        )
        self.mushroom2idx = {cls_name: idx for idx, cls_name in enumerate(self.mushroom_cls)}

        # Collect every image path together with its pair of labels.
        self.images = []
        self.labels: list[dict] = []

        for ed_name in self.ediable_cls:
            ediable_id = self.ediable2idx[ed_name]

            for mush_name in species_by_group[ed_name]:
                class_dir = os.path.join(root_dir, ed_name, mush_name)
                mush_id = self.mushroom2idx[mush_name]

                # Sorted so that sample order does not depend on the file system.
                for img_name in sorted(os.listdir(class_dir)):
                    if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                        self.images.append(os.path.join(class_dir, img_name))
                        self.labels.append({'ed_id': ediable_id,
                                            'mush_id': mush_id})

    @staticmethod
    def _list_subdirs(path):
        """Return the sorted names of the directories directly inside ``path``."""
        return sorted(
            entry for entry in os.listdir(path)
            if os.path.isdir(os.path.join(path, entry))
        )

    def __len__(self):
        """Number of images found under ``root_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, ed_id)``."""
        img_path = self.images[idx]
        label = self.labels[idx]

        # convert() returns a fully loaded copy, so the underlying file can be
        # closed right away instead of staying open until garbage collection.
        with Image.open(img_path) as raw:
            image = raw.convert('RGB')

        # Optional resize to a fixed size.
        if self.target_size:
            image = image.resize(self.target_size, Image.Resampling.LANCZOS)

        # Optional user transform (e.g. conversion to a tensor).
        if self.transform:
            image = self.transform(image)

        # Only the edibility label is used as the target; mush_id stays in self.labels.
        return image, label['ed_id']

    def get_mushrooms_name(self):
        """Return the list of species names (position == ``mush_id``)."""
        return self.mushroom_cls

    def get_ediable_name(self):
        """Return the list of edibility class names (position == ``ed_id``)."""
        return self.ediable_cls

In [6]:
# Load the dataset without augmentations; the only transform converts PIL images to torch tensors
transform = transforms.ToTensor()

# Relative path: resolved against the notebook's current working directory.
root_train = 'mushroom_dataset'
data = MushroomDataset(root_train, transform=transform, target_size=(224, 224))

In [7]:
import copy
import torchvision.transforms as transforms


class AugmentationPipeline:
    """Ordered, named collection of image augmentations applied one after another.

    A pipeline starts either empty or from one of the presets in ``configs``
    and can then be edited with ``add_augmentation`` / ``remove_augmentation``.
    """

    # Presets, selectable by name in the constructor.
    configs = {
        'light': {
            "RandomHorizontalFlip": transforms.RandomHorizontalFlip(p=0.6),
            "RandomRotation": transforms.RandomRotation(degrees=20)},

        'medium': {
            "RandomHorizontalFlip": transforms.RandomHorizontalFlip(p=0.8),
            "RandomRotation": transforms.RandomRotation(degrees=30),
            "RandomCrop": transforms.RandomCrop(size=(224, 224), padding=20)},

        'heavy': {
            "RandomHorizontalFlip": transforms.RandomHorizontalFlip(p=0.5),
            "RandomRotation": transforms.RandomRotation(degrees=45),
            "RandomGrayscale": transforms.RandomGrayscale(p=1.0),
            "GaussBlur": transforms.GaussianBlur(kernel_size=3)}}

    def __init__(self, config=None):
        """Start from the preset named ``config``, or empty when it is falsy."""
        # Deep copy so that editing this pipeline never touches the shared preset.
        self.augmentations = (
            copy.deepcopy(AugmentationPipeline.configs[config]) if config else {}
        )

    def add_augmentation(self, name, aug):
        """Register ``aug`` under ``name``, replacing any existing entry with that name."""
        self.augmentations[name] = aug

    def remove_augmentation(self, name):
        """Drop the augmentation called ``name``; unknown names are ignored."""
        self.augmentations.pop(name, None)

    def apply(self, image):
        """Run ``image`` through every augmentation in insertion order and return the result."""
        for step in self.augmentations.values():
            image = step(image)
        return image

    def __call__(self, image):
        """Calling the pipeline is the same as calling ``apply``."""
        return self.apply(image)

    def get_augmentations(self):
        """Return a shallow copy of the name -> augmentation mapping."""
        return dict(self.augmentations)

    # Lets the pipeline be handed to code that calls ``.keys()``.
    def keys(self):
        """Return a shallow copy of the augmentation dict (iterating it yields the names)."""
        return dict(self.augmentations)



def run_epoch(model, data_loader, criterion, transform, optimizer=None, device='cuda:0', is_test=False):
    """Run one pass over ``data_loader`` in training or evaluation mode.

    Args:
        model: Model called as ``model(pixel_values=..., labels=...)`` whose
            output exposes ``.loss`` and ``.logits``.
        data_loader: Iterable yielding ``(images, targets)`` batches.
        criterion: Unused. The loss is the one the model computes from
            ``labels``; the parameter is kept so existing callers keep working.
        transform: Sequence of three callables ``(light, medium, heavy)``, each
            applied per image during training. Ignored (may be ``None``) when
            ``is_test`` is true.
        optimizer: Optimizer stepped after every training batch. With ``None``
            the forward pass still runs but no weights are updated.
        device: Device the model and the batches are moved to.
        is_test: ``True`` for evaluation (eval mode, no augmentation, no
            gradients), ``False`` for training.

    Returns:
        tuple: ``(mean loss per batch, accuracy)``. In training mode the
        accuracy is measured over the augmented batch (originals plus three
        augmented copies). ``(0.0, 0.0)`` is returned for an empty loader.
    """
    training = not is_test
    if is_test:
        model.eval()
    else:
        model.train()

    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    model.to(device)
    if training:
        # Only needed for training, so evaluation callers may pass transform=None.
        light, medium, heavy = transform

    # Wrapping the loader itself (not enumerate) lets tqdm show total/ETA.
    for data, target in tqdm(data_loader):
        # Augmentation (training only): the batch grows to 4x its size.
        if training:
            with torch.no_grad():
                aug1 = torch.stack([light(img) for img in data])
                aug2 = torch.stack([medium(img) for img in data])
                aug3 = torch.stack([heavy(img) for img in data])
                data = torch.cat([data, aug1, aug2, aug3])
                target = torch.cat([target, target, target, target])

        data, target = data.to(device), target.to(device)

        if training and optimizer is not None:
            optimizer.zero_grad()

        # No autograd graph is built during evaluation, which saves memory.
        with torch.set_grad_enabled(training):
            # transformers-style ViT: passing labels makes the model return the loss.
            outputs = model(pixel_values=data, labels=target)
        loss = outputs.loss
        logits = outputs.logits

        if training and optimizer is not None:
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        pred = logits.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)
        num_batches += 1

    # Guard against an empty loader instead of raising ZeroDivisionError.
    if num_batches == 0 or total == 0:
        return 0.0, 0.0

    return total_loss / num_batches, correct / total



def train_model(model:nn.Module, train_loader, test_loader, epochs=10, lr=5e-5, device='cuda:0'):
    """Fine-tune ``model`` for ``epochs`` epochs, checkpointing into ``save_dir``.

    After every epoch the current state and the full metric history are
    written to ``transformer_last_checkpoit_aug.pt``. Whenever the test
    accuracy exceeds the best seen so far, the state is also written to
    ``transformer_best_model_aug.pt``.

    Args:
        model: Network to train.
        train_loader: Batches used for the training pass.
        test_loader: Batches used for the evaluation pass.
        epochs: Number of train + evaluate rounds.
        lr: Learning rate of the AdamW optimizer created here.
        device: Device forwarded to ``run_epoch``.

    Returns:
        dict: Per-epoch lists under ``train_losses``, ``train_accs``,
        ``test_losses`` and ``test_accs``.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=lr)

    # Metric history; key order matches the layout of the saved checkpoint.
    history = {
        'train_losses': [],
        'train_accs': [],
        'test_losses': [],
        'test_accs': [],
    }

    # One pipeline per augmentation strength, in the order run_epoch unpacks them.
    pipelines = [AugmentationPipeline(level) for level in ('light', 'medium', 'heavy')]
    top_acc = 0.0

    for epoch in tqdm(range(epochs)):
        tr_loss, tr_acc = run_epoch(model, train_loader, criterion, pipelines, optimizer, device, is_test=False)
        te_loss, te_acc = run_epoch(model, test_loader, criterion, pipelines, None, device, is_test=True)

        history['train_losses'].append(tr_loss)
        history['train_accs'].append(tr_acc)
        history['test_losses'].append(te_loss)
        history['test_accs'].append(te_acc)

        epoch_no = epoch + 1

        # Rolling checkpoint: overwritten every epoch, carries the full history.
        last_state = {
            'epoch': epoch_no,
            'model_params': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            **history,
        }
        torch.save(last_state, f'{save_dir}/transformer_last_checkpoit_aug.pt')

        # Separate file that is refreshed only on a new best test accuracy.
        if te_acc > top_acc:
            top_acc = te_acc
            best_state = {
                'epoch': epoch_no,
                'model_params': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss': te_loss,
                'accuracy': te_acc,
            }
            torch.save(best_state, f'{save_dir}/transformer_best_model_aug.pt')

    return history

In [8]:
from torch.utils.data import random_split

train_size = int(0.8 * len(data))
test_size = len(data) - train_size

# Split into train / test subsets. The generator is seeded so the partition is
# the same after a kernel restart; otherwise a model restored from a checkpoint
# could be "tested" on images it was trained on in the previous session.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(
    data, [train_size, test_size], generator=split_generator
)

# Build the DataLoaders (only the training set is shuffled).
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [9]:
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW

from tqdm import tqdm


# Optimizer and device
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# NOTE(review): train_model constructs its own AdamW and is not passed this
# object, so this optimizer is not the one used during training.
optimizer = AdamW(model.parameters(), lr=5e-5)

In [None]:
# Pass the `device` selected in the previous cell rather than a hard-coded
# 'cuda:0', so the run also works on a CPU-only runtime.
transfomers_metrics = train_model(model, train_loader, test_loader, epochs=30, lr=5e-5, device=device)

  0%|          | 0/30 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
1it [00:02,  2.62s/it][A
2it [00:05,  2.54s/it][A
3it [00:08,  3.00s/it][A
4it [00:10,  2.70s/it][A
5it [00:13,  2.55s/it][A
6it [00:15,  2.44s/it][A
7it [00:17,  2.32s/it][A
8it [00:19,  2.25s/it][A
9it [00:21,  2.25s/it][A
10it [00:23,  2.21s/it][A
11it [00:26,  2.19s/it][A
12it [00:28,  2.17s/it][A
13it [00:30,  2.16s/it][A
14it [00:32,  2.16s/it][A
15it [00:34,  2.19s/it][A
16it [00:36,  2.17s/it][A
17it [00:39,  2.17s/it][A
18it [00:41,  2.16s/it][A
19it [00:43,  2.16s/it][A
20it [00:45,  2.16s/it][A
21it [00:47,  2.18s/it][A
22it [00:49,  2.17s/it][A
23it [00:52,  2.17s/it][A
24it [00:54,  2.16s/it][A
25it [00:56,  2.16s/it][A
26it [00:58,  2.16s/it][A
27it [01:00,  2.18s/it][A
28it [01:02,  2.18s/it][A
29it [01:05,  2.17s/it][A
30it [01:07,  2.17s/it][A
31it [01:09,  2.17s/it][A
32it [01:11,  2.16s/it][A
33it [01:13,  2.20s/it][A
34it [01:16,  2.19s/it][A
35it [01:18,  2.18s/it][A
36i