<a href="https://colab.research.google.com/github/KarineAyrs/science_work/blob/main/training/improved_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install -q timm

In [2]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [3]:

class TimmModel(nn.Module):
    """Training / evaluation wrapper around a ``timm`` image classifier.

    The wrapped network lives in ``self.model`` and is moved to CUDA when it
    is available. Input batches are assumed to be single-channel images
    (e.g. MNIST); they are replicated to 3 channels before the forward pass.
    """

    # Spatial tiling factor used for ViT models: a 28x28 image tiled 8x8
    # becomes 224x224. NOTE(review): this assumes 28x28 inputs and a
    # 224-pixel ViT such as ``vit_tiny_patch16_224`` — confirm for others.
    _VIT_TILE = 8

    def __init__(self, model_name='efficientnet_b0', pretrained=True, num_classes=None):
        """Create the timm model, put it in training mode and move it to the device.

        Args:
            model_name: name passed to ``timm.create_model``.
            pretrained: whether ``timm`` should load pretrained weights.
            num_classes: optional size of the classification head; ``None``
                keeps timm's default head for ``model_name``.
        """
        super(TimmModel, self).__init__()
        self._model_name = model_name
        self._pretrained = pretrained
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Only forward num_classes when given, so the default call is unchanged.
        create_kwargs = {} if num_classes is None else {'num_classes': num_classes}
        self.model = timm.create_model(model_name=model_name, pretrained=pretrained, **create_kwargs)
        self.model.train()
        self.model.to(self.device)

    def model_config(self, learning_rate=0.001, batch_size=2, num_epochs=5, criterion=None, optimizer=None):
        """Set the training hyper-parameters, loss function and optimizer.

        Args:
            learning_rate: learning rate for the default Adam optimizer.
            batch_size: stored on the instance; not read elsewhere in this
                class (the data loader determines the real batch size).
            num_epochs: number of passes ``train_model`` makes over the loader.
            criterion: loss module; defaults to ``nn.CrossEntropyLoss()``.
            optimizer: optimizer instance; defaults to Adam over ``self.model``.
        """
        self.lr = learning_rate
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.criterion = nn.CrossEntropyLoss() if criterion is None else criterion
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate) if optimizer is None else optimizer

    def _is_configured(self):
        """Return True once ``model_config`` has set criterion, optimizer and epoch count."""
        return all(getattr(self, name, None) is not None
                   for name in ('criterion', 'optimizer', 'num_epochs'))

    def _prepare_input(self, x):
        """Move a batch to the model device and expand it to the model's input layout."""
        x = x.to(device=self.device)
        if self._model_name.startswith('vit'):
            # Replicate to 3 channels and tile spatially (see _VIT_TILE).
            return x.repeat(1, 3, self._VIT_TILE, self._VIT_TILE)
        # Replicate the single input channel to 3 channels.
        return x.repeat(1, 3, 1, 1)

    def train_model(self, train_loader):
        """Train ``self.model`` for ``self.num_epochs`` epochs over ``train_loader``.

        If ``model_config`` has not been called yet, it is called with its
        defaults. A configuration the user already set is kept as-is.

        Args:
            train_loader: iterable yielding ``(data, targets)`` batches.
        """
        if not self._is_configured():
            self.model_config()

        for epoch in range(1, self.num_epochs + 1):
            print(f'epoch:{epoch}')

            for data, targets in train_loader:
                data = self._prepare_input(data)
                targets = targets.to(device=self.device)

                scores = self.model(data)
                loss = self.criterion(scores, targets)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

    def _check_acc(self, loader):
        """Print and return top-1 accuracy (in percent) of ``self.model`` on ``loader``.

        The split label comes from ``loader.dataset.train`` when that
        attribute exists; otherwise the loader is reported as test data.
        An empty loader yields an accuracy of 0.0.
        The model is always switched back to training mode afterwards,
        including when an exception is raised during evaluation.
        """
        dataset = getattr(loader, 'dataset', None)
        split = 'training' if getattr(dataset, 'train', False) else 'test'
        print(f'Checking accuracy on {split} data')

        num_correct = 0
        num_samples = 0
        self.model.eval()
        try:
            with torch.no_grad():
                for x, y in loader:
                    x = self._prepare_input(x)
                    y = y.to(device=self.device)

                    predictions = self.model(x).argmax(dim=1)
                    num_correct += int((predictions == y).sum().item())
                    num_samples += predictions.size(0)
        finally:
            self.model.train()

        accuracy = (float(num_correct) / float(num_samples)) * 100 if num_samples else 0.0
        print(f'Got {num_correct}/{num_samples} with accuracy {accuracy}')
        return accuracy

    def check_accuracy(self, train_loader, test_loader):
        """Report accuracy on the train and test loaders.

        Returns:
            ``(train_accuracy, test_accuracy)`` in percent.
        """
        train_acc = self._check_acc(train_loader)
        test_acc = self._check_acc(test_loader)
        return train_acc, test_acc


In [4]:
class MNIST:
  """MNIST train/test splits wrapped in ``DataLoader``s.

  Both splits are converted with ``transforms.ToTensor()`` and downloaded to
  ``root`` if they are not already present there.
  """

  def __init__(self, batch_size=2, root='dataset/'):
    """Load both MNIST splits and build their loaders.

    Args:
      batch_size: batch size used by both loaders.
      root: directory where torchvision stores / looks for the dataset.
    """
    self.batch_size = batch_size
    self._train_dataset = datasets.MNIST(root=root, train=True, transform=transforms.ToTensor(), download=True)
    self._train_loader = DataLoader(dataset=self._train_dataset, batch_size=batch_size, shuffle=True)

    self._test_dataset = datasets.MNIST(root=root, train=False, transform=transforms.ToTensor(), download=True)
    # Evaluation does not benefit from shuffling; a fixed order keeps test runs reproducible.
    self._test_loader = DataLoader(dataset=self._test_dataset, batch_size=batch_size, shuffle=False)

  def train_loader(self):
    """Return the shuffled training ``DataLoader``."""
    return self._train_loader

  def test_loader(self):
    """Return the test ``DataLoader`` (fixed order)."""
    return self._test_loader

In [None]:
# Seed torch's global RNG so random initialisation and DataLoader shuffling
# are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model = TimmModel('vit_tiny_patch16_224')
mnist = MNIST()

train_loader = mnist.train_loader()
test_loader = mnist.test_loader()


In [6]:
# Train with TimmModel.model_config's default hyper-parameters.
model.train_model(train_loader)

epoch:1
epoch:2
epoch:3
epoch:4
epoch:5


In [7]:
# Report accuracy on both splits.
model.check_accuracy(train_loader=train_loader, test_loader=test_loader)

Checking accuracy on training data
Got 43494/60000 with accuracy 72.49
Checking accuracy o test data
Got 7228/10000 with accuracy 72.28
