In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import numpy as np
from torchvision import transforms
from tqdm import tqdm
import timm

In [2]:
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ['TORCH_USE_CUDA_DSA'] = "1"

In [3]:
# importing the zipfile module
from zipfile import ZipFile

# loading the temp.zip and creating a zip object
with ZipFile("/content/mushrooms_small.zip", 'r') as zObject:

    # Extracting all the members of the zip
    # into a specific location.
    zObject.extractall(
        path="/content")

In [4]:
class MushroomDataset(Dataset):
    """Кастомный датасет для работы с папками классов"""
    def __init__(self, root_dir, transform=None, target_size=None):
        """
        Args:
            root_dir (str): Путь к папке с классами
            transform: Первичная аугментации для изображений
            target_size (tuple): Размер для ресайза изображений
        """
        self.root_dir = root_dir
        self.transform = transform
        self.target_size = target_size

        # Получаем список классов (папок)
        self.ediable_cls = sorted([d for d in os.listdir(root_dir)   if os.path.isdir(os.path.join(root_dir, d))])
        self.ediable2idx = {cls_name: idx for idx, cls_name in enumerate(self.ediable_cls)}

        self.mushroom_cls = [os.listdir(os.path.join(root_dir, dir_name)) for dir_name in self.ediable_cls]
        self.mushroom_cls = [d for mushdir in self.mushroom_cls for d in mushdir]
        self.mushroom2idx = {cls_name: idx for idx, cls_name in enumerate(self.mushroom_cls)}


        # Собираем все пути к изображениям
        self.images = []
        self.labels: list[dict] = []

        for ed_name in self.ediable_cls:
            ediable_dir = os.path.join(root_dir, ed_name)
            ediable_id = self.ediable2idx[ed_name]

            for mush_name in os.listdir(ediable_dir):
                class_dir = os.path.join(ediable_dir, mush_name)
                mush_id = self.mushroom2idx[mush_name]

                for img_name in os.listdir(class_dir):
                    if img_name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')):
                        img_path = os.path.join(class_dir, img_name)
                        self.images.append(img_path)
                        self.labels.append({'ed_id': ediable_id,
                                            'mush_id': mush_id})


    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]

        # Загружаем изображение
        image = Image.open(img_path).convert('RGB')

        # Ресайзим изображение
        if self.target_size:
            image = image.resize(self.target_size, Image.Resampling.LANCZOS)

        # Применяем аугментации
        if self.transform:
            image = self.transform(image)

        # image = image.view()

        '''СТАВИМ ТОЛЬКО!!! ed_id '''
        return image, label['ed_id']

    def get_mushrooms_name(self):
        """Возвращает список имен видов грибов"""
        return self.mushroom_cls

    def get_ediable_name(self):
        """Возвращает список о съедобности"""
        return self.ediable_cls

In [5]:
def get_sinusoid_encoding(num_tokens, token_len):
    """ Make Sinusoid Encoding Table

        Args:
            num_tokens (int): number of tokens
            token_len (int): length of a token

        Returns:
            (torch.FloatTensor) sinusoidal position encoding table
    """

    def get_position_angle_vec(i):
        return [i / np.power(10000, 2 * (j // 2) / token_len) for j in range(token_len)]

    sinusoid_table = np.array([get_position_angle_vec(i) for i in range(num_tokens)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])

    return torch.FloatTensor(sinusoid_table).unsqueeze(0)

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class NeuralNet(nn.Module):
    def __init__(self,
       in_chan: int,
       hidden_chan =None,
       out_chan =None,
       act_layer = nn.GELU):
        """ Neural Network Module

            Args:
                in_chan (int): number of channels (features) at input
                hidden_chan (NoneFloat): number of channels (features) in the hidden layer;
                                        if None, number of channels in hidden layer is the same as the number of input channels
                out_chan (NoneFloat): number of channels (features) at output;
                                        if None, number of output channels is same as the number of input channels
                act_layer(nn.modules.activation): torch neural network layer class to use as activation
        """

        super().__init__()

        ## Define Number of Channels
        hidden_chan = hidden_chan or in_chan
        out_chan = out_chan or in_chan

        ## Define Layers
        self.fc1 = nn.Linear(in_chan, hidden_chan)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_chan, out_chan)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

class Encoding(nn.Module):
    def __init__(self,
       dim: int,
       num_heads: int=1,
       hidden_chan_mul: float=4.,
       qkv_bias: bool=False,
       qk_scale =None,
       act_layer=nn.GELU,
       norm_layer=nn.LayerNorm):

        """ Encoding Block

            Args:
                dim (int): size of a single token
                num_heads(int): number of attention heads in MSA
                hidden_chan_mul (float): multiplier to determine the number of hidden channels (features) in the NeuralNet component
                qkv_bias (bool): determines if the qkv layer learns an addative bias
                qk_scale (NoneFloat): value to scale the queries and keys by;
                                    if None, queries and keys are scaled by ``head_dim ** -0.5``
                act_layer(nn.modules.activation): torch neural network layer class to use as activation
                norm_layer(nn.modules.normalization): torch neural network layer class to use as normalization
        """

        super().__init__()

        ## Define Layers
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim=dim,
                            num_heads=num_heads,
                            qkv_bias=qkv_bias,
                            qk_scale=qk_scale)
        self.norm2 = norm_layer(dim)
        self.neuralnet = NeuralNet(in_chan=dim,
                                hidden_chan=int(dim*hidden_chan_mul),
                                out_chan=dim,
                                act_layer=act_layer)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.neuralnet(self.norm2(x))
        return x

class Patch_Tokenization(nn.Module):
    def __init__(self,
                img_size: tuple[int, int, int]=(3, 224, 244),
                patch_size: int=16,
                token_len: int=768,
                batch: int = 32):

        """ Patch Tokenization Module
            Args:
                img_size (tuple[int, int, int]): size of input (channels, height, width)
                patch_size (int): the side length of a square patch
                token_len (int): desired length of an output token
        """
        super().__init__()

        ## Defining Parameters
        self.img_size = img_size
        C, H, W = self.img_size
        self.patch_size = patch_size
        self.token_len = token_len
        assert H % self.patch_size == 0, 'Height of image must be evenly divisible by patch size.'
        assert W % self.patch_size == 0, 'Width of image must be evenly divisible by patch size.'
        self.num_tokens = int((H / self.patch_size) * (W / self.patch_size))

        ## Defining Layers
        self.split = nn.Unfold(kernel_size=self.patch_size, stride=self.patch_size, padding=0)
        self.project = nn.Linear((self.patch_size**2)*C, token_len)

    def forward(self, x):
        print(x.size())
        x = self.split(x).transpose(1,2)
        print(x.size())
        x = self.project(x)
        # print(x, '\n', x.size())
        return x

class ViT_Backbone(nn.Module):
    def __init__(self,
                num_tokens,
                preds: int=1,
                token_len: int=768,
                num_heads: int=1,
                Encoding_hidden_chan_mul: float=4.,
                depth: int=12,
                qkv_bias=False,
                qk_scale=None,
                act_layer=nn.GELU,
                norm_layer=nn.LayerNorm):

        """ VisTransformer Backbone
            Args:
                preds (int): number of predictions to output
                token_len (int): length of a token
                num_heads(int): number of attention heads in MSA
                Encoding_hidden_chan_mul (float): multiplier to determine the number of hidden channels (features) in the NeuralNet component of the Encoding Module
                depth (int): number of encoding blocks in the model
                qkv_bias (bool): determines if the qkv layer learns an addative bias
                qk_scale (NoneFloat): value to scale the queries and keys by;
                 if None, queries and keys are scaled by ``head_dim ** -0.5``
                act_layer(nn.modules.activation): torch neural network layer class to use as activation
                norm_layer(nn.modules.normalization): torch neural network layer class to use as normalization
        """

        super().__init__()

        ## Defining Parameters
        self.num_heads = num_heads
        self.Encoding_hidden_chan_mul = Encoding_hidden_chan_mul
        self.depth = depth
        self.token_len = token_len
        self.num_tokens = num_tokens

        ## Defining Token Processing Components
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.token_len))
        # self.cls_token = torch.transpose(self.cls_token, 0, 1)
        self.pos_embed = nn.Parameter(data=get_sinusoid_encoding(num_tokens=self.num_tokens+1, token_len=self.token_len), requires_grad=False)

        ## Defining Encoding blocks
        self.blocks = nn.ModuleList([Encoding(dim = self.token_len,
                                               num_heads = self.num_heads,
                                               hidden_chan_mul = self.Encoding_hidden_chan_mul,
                                               qkv_bias = qkv_bias,
                                               qk_scale = qk_scale,
                                               act_layer = act_layer,
                                               norm_layer = norm_layer)
             for i in range(self.depth)])

        ## Defining Prediction Processing
        self.norm = norm_layer(self.token_len)
        self.head = nn.Linear(self.token_len, preds)

        ## Make the class token sampled from a truncated normal distrobution
        timm.layers.trunc_normal_(self.cls_token, std=.02)

    def forward(self, x):
        ## Assumes x is already tokenized

        ## Get Batch Size
        B = x.shape[0]
        ## Concatenate Class Token
        print(x.size())
        # print(self.cls_token.expand(B, -1, -1))
        x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
        print(x.size())
        print(self.cls_token.size())
        # x = torch.transpose(x, 0, 1)
        # x = torch.cat((self.cls_token, x), dim=1)
        # x = torch.transpose(x, 0, 1)
        ## Add Positional Embedding
        x = x + self.pos_embed
        ## Run Through Encoding Blocks
        for blk in self.blocks:
            x = blk(x)
        ## Take Norm
        x = self.norm(x)
        ## Make Prediction on Class Token
        x = self.head(x[:, 0])
        return x

class MushroomViTModel(nn.Module):
    def __init__(self,
        img_size: tuple[int, int, int]=(3, 224, 224),
        patch_size: int=16,
        token_len: int=768,
        preds: int=1,
        num_heads: int=1,
        Encoding_hidden_chan_mul: float=4.,
        depth: int=12,
        qkv_bias=False,
        qk_scale=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm):

      """ VisTransformer Model

      Args:
        img_size (tuple[int, int, int]): size of input (channels, height, width)
        patch_size (int): the side length of a square patch
        token_len (int): desired length of an output token
        preds (int): number of predictions to output
        num_heads(int): number of attention heads in MSA
        Encoding_hidden_chan_mul (float): multiplier to determine the number of hidden channels (features) in the NeuralNet component of the Encoding Module
        depth (int): number of encoding blocks in the model
        qkv_bias (bool): determines if the qkv layer learns an addative bias
        qk_scale (NoneFloat): value to scale the queries and keys by;
            if None, queries and keys are scaled by ``head_dim ** -0.5``
        act_layer(nn.modules.activation): torch neural network layer class to use as activation
        norm_layer(nn.modules.normalization): torch neural network layer class to use as normalization
      """
      super().__init__()

      ## Defining Parameters
      self.img_size = img_size
      C, H, W = self.img_size
      self.patch_size = patch_size
      self.token_len = token_len
      self.num_heads = num_heads
      self.Encoding_hidden_chan_mul = Encoding_hidden_chan_mul
      self.depth = depth

      ## Defining Patch Embedding Module
      self.patch_tokens = Patch_Tokenization(img_size,
              patch_size,
              token_len)
      self.num_tokens = self.patch_tokens.num_tokens

      ## Defining ViT Backbone
      self.backbone = ViT_Backbone(self.num_tokens,
            preds,
            self.token_len,
            self.num_heads,
            self.Encoding_hidden_chan_mul,
            self.depth,
            qkv_bias,
            qk_scale,
            act_layer,
            norm_layer)
      ## Initialize the Weights
      self.apply(self._init_weights)

    def _init_weights(self, m):
      """ Initialize the weights of the linear layers & the layernorms
      """
      ## For Linear Layers
      if isinstance(m, nn.Linear):
        ## Weights are initialized from a truncated normal distrobution
        timm.layers.trunc_normal_(m.weight, std=.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
          ## If bias is present, bias is initialized at zero
          nn.init.constant_(m.bias, 0)
      ## For Layernorm Layers
      elif isinstance(m, nn.LayerNorm):
        ## Weights are initialized at one
        nn.init.constant_(m.weight, 1.0)
        ## Bias is initialized at zero
        nn.init.constant_(m.bias, 0)

    @torch.jit.ignore ##Tell pytorch to not compile as TorchScript
    def no_weight_decay(self):
      """ Used in Optimizer to ignore weight decay in the class token
      """
      return {'cls_token'}

    def forward(self, x):
      batch_size = x.shape[0]
      x = self.patch_tokens(x)
      x = self.backbone(x)
      return x

In [6]:
# Загрузка датасета без аугментаций с преобразованием PIL --> torch.tensor()
transform = transforms.ToTensor()

root_train = '/content/mushroom_dataset/'
data = MushroomDataset(root_train, transform=transform, target_size=(224, 224))

In [7]:
def run_epoch(model, data_loader, criterion, optimizer=None, device='cuda:0', is_test=False):
    if is_test:
        model.eval()
    else:
        model.train()

    total_loss = 0
    correct = 0
    total = 0

    # Переносим модель на устройство
    model.to(device)

    for batch_idx, (data, target) in enumerate(data_loader):
        data, target = data.to(device), target.to(device)

        if not is_test and optimizer is not None:
            optimizer.zero_grad()
        print(data.size(), target.size())
        # output = torch.tensor([])
        # loss = torch.tensor([])
        # for img, label in  enumerate((data, target)):
        #     output_img = model(img)
        #     loss_img = criterion(output_img, label)
        #     output.append(output_img)
        #     loss.append(loss_img)
        output = model(data)
        print(output, output.size())
        loss = criterion(output, target)

        if not is_test and optimizer is not None:
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        total += target.size(0)

    return total_loss / len(data_loader), correct / total


def train_model(model:nn.Module, train_loader, test_loader, epochs=10, lr=0.001, device='cuda:0'):
    criterion = nn.CrossEntropyLoss()
    # Поробуем SGD
    optimizer = optim.SGD(model.parameters(), lr=0.007, momentum=0.9, weight_decay=0.0002)

    train_losses, train_accs = [], []
    test_losses, test_accs = [], []

    best_acc = 0.0

    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = run_epoch(model, train_loader, criterion, optimizer, device, is_test=False)
        test_loss, test_acc = run_epoch(model, test_loader, criterion, None, device, is_test=True)

        train_losses.append(train_loss)
        train_accs.append(train_acc)
        test_losses.append(test_loss)
        test_accs.append(test_acc)

        torch.save({
                'epoch': epoch+1,
                'model_params': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'train_losses': train_losses,
                'train_accs': train_accs,
                'test_losses': test_losses,
                'test_accs': test_accs
            }, f'compactCNN_last_checkpoit.pt')


        if test_acc > best_acc:
            best_acc = test_acc
            torch.save({
                'epoch': epoch+1,
                'model_params': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss': test_loss,
                'accuracy': test_acc
            }, f'compactCNN_best_model.pt')


    return {
        'train_losses': train_losses,
        'train_accs': train_accs,
        'test_losses': test_losses,
        'test_accs': test_accs
    }

In [8]:
train_size = int(0.8 * len(data))
test_size = len(data) - train_size

# Делим на train, test выборки
train_dataset, test_dataset = random_split(data, [train_size, test_size])

# Создаём DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [9]:
mushroom_vit = MushroomViTModel()

metric_compactcnn = train_model(mushroom_vit, train_loader, test_loader, epochs=20, lr=0.0005, device='cuda:0')

  0%|          | 0/20 [00:00<?, ?it/s]

torch.Size([32, 3, 224, 224]) torch.Size([32])
torch.Size([32, 3, 224, 224])
torch.Size([32, 196, 768])
torch.Size([32, 196, 768])
torch.Size([32, 197, 768])
torch.Size([1, 1, 768])


  0%|          | 0/20 [00:01<?, ?it/s]

tensor([[0.2444],
        [0.2446],
        [0.2806],
        [0.2308],
        [0.2481],
        [0.2965],
        [0.2244],
        [0.2512],
        [0.2248],
        [0.2834],
        [0.2662],
        [0.2825],
        [0.2675],
        [0.2270],
        [0.2385],
        [0.2270],
        [0.2723],
        [0.2589],
        [0.2688],
        [0.2539],
        [0.2353],
        [0.2257],
        [0.2779],
        [0.2647],
        [0.2458],
        [0.2185],
        [0.2629],
        [0.2507],
        [0.2556],
        [0.2423],
        [0.2514],
        [0.2577]], device='cuda:0', grad_fn=<AddmmBackward0>) torch.Size([32, 1])





RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
