VIT

In [None]:
!pip install einops

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.6/44.6 kB 935.7 kB/s eta 0:00:00
Installing collected packages: einops
Successfully installed einops-0.7.0


In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary


class PatchEmbedding(nn.Module):
    """Turn a batch of images into a sequence of patch embeddings.

    A strided convolution cuts each image into non-overlapping
    ``patch_size`` x ``patch_size`` patches and projects every patch to an
    ``embed_dim``-sized vector. A learnable class token is prepended to the
    patch sequence and learnable position embeddings are added.

    Args:
        img_size: Side length of the square input images, in pixels.
        patch_size: Side length of each square patch, in pixels.
        in_chans: Number of input image channels.
        embed_dim: Size of each embedding vector.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        # One position per patch, plus one for the class token.
        num_positions = (img_size // patch_size) ** 2 + 1
        self.position_embeddings = nn.Parameter(torch.rand(num_positions, embed_dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        # kernel_size == stride, so every output location sees exactly one
        # patch. This is the only projection layer: a second, never-used
        # convolution would just add dead trainable parameters.
        self.projection = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e h w -> b (h w) e'),
        )

    def forward(self, image):
        """Embed ``image``.

        Args:
            image: Tensor of shape ``(batch, in_chans, height, width)``.

        Returns:
            Tensor of shape ``(batch, num_patches + 1, embed_dim)``; index 0
            along the sequence axis is the class token.
        """
        batch = image.shape[0]
        x = self.projection(image)

        # Give every sample in the batch its own copy of the class token.
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=batch)
        x = torch.cat([cls_tokens, x], dim=1)

        # Out-of-place add: leaves the concatenated tensor untouched.
        return x + self.position_embeddings


class MLP(nn.Module):
    """Two-layer feed-forward network.

    Applies ``Linear -> ReLU -> Dropout -> Linear -> Dropout``.

    Args:
        in_features: Size of the input feature vector.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when not given.
        out_features: Size of the output feature vector; falls back to
            ``in_features`` when not given.
        drop: Dropout probability used after both linear layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()

        # Missing (or zero) sizes default to the input width.
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        self.act = nn.ReLU()

    def forward(self, x):
        """Run ``x`` through both layers and return the result."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))



class Attention(nn.Module):
    """Multi-head self-attention over a sequence of tokens.

    Args:
        dim: Embedding size of every token; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the joint query/key/value projection has a bias.
        attn_drop: Dropout probability applied to the attention weights.
        out_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., out_drop=0.):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Scaling the dot products by 1/sqrt(head_dim).
        self.scale = self.head_dim ** -0.5

        # A single linear layer produces queries, keys and values at once.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.out = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(out_drop)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: Tensor of shape ``(batch, tokens, dim)``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        B, N, C = x.shape

        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        # Move the head axis in front of the token axis so each of q, k, v
        # is (B, heads, N, head_dim). Without this the matmul below would
        # compare heads inside one token instead of tokens with each other.
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        # (B, heads, N, N): each token's weights over all tokens.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # (B, heads, N, head_dim) -> (B, N, heads, head_dim) -> (B, N, C):
        # the heads are concatenated back along the feature axis.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.out(x)
        x = self.out_drop(x)
        return x


class Block(nn.Module):
    """One pre-norm Transformer encoder block.

    Computes ``x = x + drop(attention(norm1(x)))`` followed by
    ``x = x + drop(mlp(norm2(x)))``.

    Args:
        dim: Embedding size of every token.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        drop_rate: Dropout probability applied to each sub-layer's output
            before it is added to the residual.
    """

    def __init__(self, dim, num_heads=8, mlp_ratio=4, drop_rate=0.):
        super().__init__()

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # Forward the requested head count instead of silently falling back
        # to the Attention default.
        self.attention = Attention(dim, num_heads=num_heads)

        self.drop = nn.Dropout(drop_rate)

        # Hidden width follows mlp_ratio (dim=768, mlp_ratio=4 -> 3072).
        self.mlp = MLP(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            out_features=dim,
        )

    def forward(self, x):
        """Apply the attention and MLP sub-layers, each with a residual."""
        # Attention sub-layer.
        residual = x
        x = self.norm1(x)
        x = self.attention(x)
        x = self.drop(x)
        x = x + residual

        # MLP sub-layer.
        residual = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = self.drop(x)
        x = x + residual
        return x

class Transformer(nn.Module):
    """A stack of ``depth`` encoder blocks applied one after another.

    Args:
        depth: Number of ``Block`` layers.
        dim: Embedding size of every token.
        num_heads: Passed through to every ``Block``.
        mlp_ratio: Passed through to every ``Block``.
        drop_rate: Passed through to every ``Block``.
    """

    def __init__(self, depth, dim, num_heads=8, mlp_ratio=4, drop_rate=0.):
        super().__init__()
        self.blocks = nn.ModuleList(
            Block(dim, num_heads, mlp_ratio, drop_rate) for _ in range(depth)
        )

    def forward(self, x):
        """Feed ``x`` through every block in order and return the result."""
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden)
        return hidden


from torch.nn.modules.normalization import LayerNorm

class ViT(nn.Module):
    """Vision Transformer image classifier.

    Pipeline: patch embedding (with class token and position embeddings)
    -> Transformer encoder -> linear classifier applied to the class token.

    Args:
        img_size: Side length of the square input images, in pixels.
        patch_size: Side length of each square patch, in pixels.
        in_chans: Number of input image channels.
        num_classes: Number of output classes.
        embed_dim: Embedding size of every token.
        depth: Number of encoder blocks.
        num_heads: Number of attention heads requested for each block.
        mlp_ratio: MLP hidden-width multiplier requested for each block.
        qkv_bias: Accepted for API compatibility but not forwarded, because
            ``Transformer`` takes no such argument.
        drop_rate: Dropout probability forwarded to the encoder.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
                 qkv_bias=False, drop_rate=0.,):
        super().__init__()

        # Keep the configuration on the instance.
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.drop_rate = drop_rate

        # Patch embeddings, class token, position embeddings.
        self.patch_embedding = PatchEmbedding(
            img_size=self.img_size,
            patch_size=self.patch_size,
            in_chans=self.in_chans,
            embed_dim=self.embed_dim,
        )

        # Transformer encoder. The caller's drop_rate is forwarded rather
        # than a hard-coded 0, so the argument actually takes effect.
        self.transformer = Transformer(
            depth=self.depth,
            dim=self.embed_dim,
            num_heads=self.num_heads,
            mlp_ratio=self.mlp_ratio,
            drop_rate=self.drop_rate,
        )

        # Classifier head.
        self.classifier = nn.Linear(in_features=self.embed_dim, out_features=self.num_classes)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image tensor of shape ``(batch, in_chans, height, width)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalized class
            scores.
        """
        # Patch embeddings, class token, position embeddings.
        x = self.patch_embedding(x)

        # Transformer encoder.
        x = self.transformer(x)

        # Classify from the first token of the sequence (the class token).
        x = self.classifier(x[:, 0])
        return x

init.py

In [None]:
# Label names, in class-index order.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
num_classes = len(classes)

# Shape of one input sample — presumably (channels, height, width); confirm
# against the data pipeline.
dims = (3, 32, 32)

batch_size = 4

data.py

In [None]:
!pip install lightning

In [None]:
import sys
sys.path.append(".")

#import classificator
import torch
import torchvision
import torchvision.transforms as transforms

import os
import lightning as L
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split


from torchvision.datasets import CIFAR10

PATH_DATASETS = os.environ.get('PATH_DATASETS',".")
BATCH_SIZE = 256 if torch.cuda.is_available() else 64



class L_data_module(L.LightningDataModule):
    """Lightning data module serving CIFAR-10 train/val/test loaders.

    Args:
        data_dir: Directory where the dataset is downloaded to and read from.
    """

    # Number of training images held out for validation.
    VAL_SIZE = 5000

    def __init__(self, data_dir: str = PATH_DATASETS):
        super().__init__()
        self.data_dir = data_dir
        # Convert to tensors, then map every channel from [0, 1] to [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        self.dims = dims
        self.num_classes = num_classes

    def prepare_data(self):
        """Download both CIFAR-10 splits into ``data_dir``."""
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        ``"fit"`` builds the train/validation split, ``"test"`` builds the
        test set, and ``None`` builds both.
        """
        # Assign train/val datasets for use in dataloaders.
        if stage == "fit" or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            # random_split requires the lengths to add up to the dataset
            # size, so derive the training size from the actual length.
            train_size = len(cifar_full) - self.VAL_SIZE
            self.cifar_train, self.cifar_val = random_split(
                cifar_full, [train_size, self.VAL_SIZE]
            )

        # Assign test dataset for use in dataloader(s).
        if stage == "test" or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the training loader; samples are reshuffled every epoch."""
        return DataLoader(self.cifar_train, batch_size=BATCH_SIZE, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader."""
        return DataLoader(self.cifar_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        """Return the test loader."""
        return DataLoader(self.cifar_test, batch_size=BATCH_SIZE)

model.py

In [None]:
# import sys
# sys.path.append(".")

# import classificator


import torch
import torch.nn as nn
import torch.nn.functional as F

import lightning as L
from torchmetrics.functional import accuracy


class L_model(L.LightningModule):
    """Lightning wrapper that trains a ViT classifier.

    Args:
        channels: Number of input image channels.
        width: Input image width in pixels.
        height: Input image height in pixels; must equal ``width``.
        num_classes: Number of target classes.
        hidden_size: Unused; kept so existing callers keep working.
        lr: Learning rate for the Adam optimizer.

    Raises:
        ValueError: If ``width`` and ``height`` differ.
    """

    def __init__(self, channels, width, height, num_classes, hidden_size=64, lr=1e-4):
        super().__init__()

        # ViT is configured with a single img_size, so the input must be
        # square.
        if width != height:
            raise ValueError(
                f"ViT expects square images, got width={width}, height={height}"
            )

        self.num_classes = num_classes
        self.lr = lr

        # Build the network from the constructor arguments rather than from
        # module-level globals.
        self.model = nn.Sequential(
            ViT(img_size=width, patch_size=16, in_chans=channels,
                num_classes=self.num_classes,
                embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
                qkv_bias=False, drop_rate=0.,)
        )

    def forward(self, x):
        """Return per-class log-probabilities for the batch ``x``."""
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx=None):
        """Compute, log and return the negative log-likelihood loss."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and accuracy for one batch."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        # Use the configured class count, not a hard-coded 10.
        acc = accuracy(preds, y, task='multiclass', num_classes=self.num_classes)

        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer



train.py

In [None]:
!pip install wandb

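[REDACTED — this looked like a W&B API key. Do not store credentials in the notebook; use `wandb login` or the WANDB_API_KEY environment variable, and rotate the exposed key.]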

In [None]:
import sys
sys.path.append(".")

# import classificator.data as d
# import classificator.model as model

import torch.nn as nn
import torch.optim as optim

import lightning as L

import wandb

# Take the logger from the same `lightning` package that provides the
# Trainer; mixing it with the standalone `pytorch_lightning` package is
# unsupported.
from lightning.pytorch.loggers import WandbLogger

# Name the run when it is started. A WandbLogger created while a run is
# already active reuses that run, so a name given only to the logger would
# not be applied.
run = wandb.init(project="cool-girly-pytorch-project", name='ps4lr1-4')

# Attach the logger to the run created above; log_model='all' uploads
# checkpoints during training.
wandb_logger = WandbLogger(experiment=run, log_model='all')

dm = L_data_module()
model = L_model(3, 32, 32, num_classes=num_classes)
trainer = L.Trainer(
    max_epochs=5,
    accelerator='auto',
    devices=1,
    logger=wandb_logger,
)
trainer.fit(model, dm)


<IPython.core.display.Javascript object>