<a href="https://colab.research.google.com/github/KarineAyrs/science_work/blob/main/training/improved_classification_with_wandb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install timm

Collecting timm
  Downloading timm-0.4.12-py3-none-any.whl (376 kB)
[?25l[K     |▉                               | 10 kB 22.3 MB/s eta 0:00:01[K     |█▊                              | 20 kB 10.7 MB/s eta 0:00:01[K     |██▋                             | 30 kB 8.7 MB/s eta 0:00:01[K     |███▌                            | 40 kB 7.8 MB/s eta 0:00:01[K     |████▍                           | 51 kB 5.5 MB/s eta 0:00:01[K     |█████▏                          | 61 kB 5.6 MB/s eta 0:00:01[K     |██████                          | 71 kB 5.4 MB/s eta 0:00:01[K     |███████                         | 81 kB 6.0 MB/s eta 0:00:01[K     |███████▉                        | 92 kB 6.1 MB/s eta 0:00:01[K     |████████▊                       | 102 kB 5.3 MB/s eta 0:00:01[K     |█████████▋                      | 112 kB 5.3 MB/s eta 0:00:01[K     |██████████▍                     | 122 kB 5.3 MB/s eta 0:00:01[K     |███████████▎                    | 133 kB 5.3 MB/s eta 0:00:01[K     |

In [None]:
!pip install wandb

In [None]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import wandb

In [None]:
# Authenticate this session with Weights & Biases before any run is started.
# (In the recorded output the API key is appended to the user's netrc file.)
wandb.login()



<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
class TimmModel(nn.Module):
    """Training/evaluation wrapper around a ``timm`` image classifier.

    The wrapped network is stored in ``self.model``.  Training loss and
    accuracy are reported to Weights & Biases.  Every input batch has its
    channel dimension repeated three times before it is fed to the network
    (see ``_prepare_input``), which turns 1-channel batches such as MNIST
    into 3-channel ones.
    """

    def __init__(self, model_name='efficientnet_b0', pretrained=True):
        """Create the timm network and move it to the available device.

        Args:
            model_name: name passed to ``timm.create_model``.  Names starting
                with ``'vit'`` additionally get their inputs tiled spatially.
            pretrained: forwarded to ``timm.create_model``.
        """
        super().__init__()
        self._model_name = model_name
        self._pretrained = pretrained
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = timm.create_model(model_name=model_name, pretrained=pretrained)
        self.model.train()
        self.model.to(self.device)
        # Set to True by model_config(); train_model() falls back to the
        # default configuration while this is still False.
        self.config = False

    def model_config(self, learning_rate=0.001, batch_size=2, num_epochs=5, criterion=None, optimizer=None):
        """Store the training hyperparameters, loss and optimizer.

        Args:
            learning_rate: learning rate for the default Adam optimizer; also
                recorded in the W&B run config.
            batch_size: only recorded in the W&B run config; the batch size
                actually used is the one of the loader passed to train_model().
            num_epochs: number of passes over the training loader.
            criterion: loss callable; defaults to ``nn.CrossEntropyLoss()``.
            optimizer: optimizer instance; defaults to Adam over the wrapped
                network's parameters.
        """
        self.lr = learning_rate
        self.batch_size = batch_size
        # Honour the caller's value (this used to be hard-coded to 5).
        self.num_epochs = num_epochs
        self.criterion = nn.CrossEntropyLoss() if criterion is None else criterion
        if optimizer is None:
            optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        self.optimizer = optimizer
        self.config = True

    def _prepare_input(self, images):
        """Expand a batch to the layout the wrapped network is fed with.

        The channel dimension is repeated 3 times.  For model names starting
        with ``'vit'`` height and width are additionally tiled 8 times each,
        so a 28x28 image becomes 224x224.
        """
        if self._model_name.startswith('vit'):
            return images.repeat(1, 3, 8, 8)
        return images.repeat(1, 3, 1, 1)

    def train_model(self, train_loader):
        """Train ``self.model`` for ``self.num_epochs`` epochs.

        Starts a W&B run in the ``mnist-classification`` project and logs the
        loss every 25 batches.

        Args:
            train_loader: iterable yielding ``(data, targets)`` batches.
        """
        if not self.config:
            self.model_config()

        # Hyperparameters go through init(config=...) so they are attached to
        # the run; rebinding ``wandb.config`` to a dict does not do that.
        wandb.init(
            project='mnist-classification',
            config={
                'learning_rate': self.lr,
                'epochs': self.num_epochs,
                'batch_size': self.batch_size,
            },
        )
        # Watch the network owned by this instance rather than a notebook
        # global that merely happens to be called ``model``.
        wandb.watch(self.model, criterion=self.criterion, log='all', log_freq=10)

        self.model.train()
        example_ct = 0  # number of examples seen so far; used as the W&B step

        for epoch in range(1, self.num_epochs + 1):
            print(f'epoch:{epoch}')

            for batch, (data, targets) in enumerate(train_loader):
                data = self._prepare_input(data.to(device=self.device))
                targets = targets.to(device=self.device)
                example_ct += len(data)

                scores = self.model(data)
                loss = self.criterion(scores, targets)

                if (batch + 1) % 25 == 0:
                    self._train_log(loss, example_ct, epoch)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

    def _train_log(self, loss, example_ct, epoch):
        """Report the current loss to W&B and to stdout.

        Args:
            loss: scalar loss, either a 0-dim tensor or a number.
            example_ct: number of examples seen so far (W&B step).
            epoch: current epoch number.
        """
        # Log a plain float so the autograd graph is not kept alive by the
        # logged value.
        loss_value = loss.item() if torch.is_tensor(loss) else float(loss)
        wandb.log({'epoch': epoch, 'loss': loss_value}, step=example_ct)
        print('Loss after ' + str(example_ct).zfill(5) + f' examples: {loss_value:.3f}')

    def _check_acc(self, loader):
        """Measure top-1 accuracy on ``loader``, then export the model to ONNX.

        Prints the result, and logs ``train_accuracy`` / ``test_accuracy`` and
        uploads ``model.onnx`` when a W&B run is active.  The network is put
        back into training mode afterwards, even if evaluation fails.  An
        empty loader is reported and skipped instead of raising.

        Args:
            loader: iterable of ``(x, y)`` batches whose ``dataset.train``
                flag selects the ``train`` / ``test`` label.
        """
        msg = 'train' if loader.dataset.train else 'test'

        print('Checking accuracy on ' + msg + ' data')

        num_correct = 0
        num_samples = 0
        last_batch = None  # example input for the ONNX export
        self.model.eval()

        try:
            with torch.no_grad():
                for x, y in loader:
                    x = self._prepare_input(x.to(device=self.device))
                    y = y.to(device=self.device)

                    predictions = self.model(x).argmax(dim=1)
                    num_correct += (predictions == y).sum().item()
                    num_samples += predictions.size(0)
                    last_batch = x

            if num_samples == 0:
                # Nothing to divide by and no example input to export with.
                print('No samples in ' + msg + ' data; skipping accuracy report')
                return

            accuracy = num_correct / num_samples
            print(f'Got {num_correct}/{num_samples} with accuracy {accuracy * 100}')

            run_active = wandb.run is not None
            if run_active:
                wandb.log({msg + '_accuracy': accuracy})

            # Exported while still in eval mode, using the last batch as the
            # example input.
            torch.onnx.export(self.model, last_batch, 'model.onnx')
            if run_active:
                wandb.save('model.onnx')
        finally:
            self.model.train()

    def check_accuracy(self, train_loader, test_loader):
        """Report accuracy on the training loader, then on the test loader."""
        self._check_acc(train_loader)
        self._check_acc(test_loader)


In [None]:
class MNIST:
    """Holds the torchvision MNIST train/test datasets and their loaders.

    Both splits are downloaded into ``root`` if missing and converted with
    ``transforms.ToTensor()``.
    """

    def __init__(self, batch_size=2, root='dataset/'):
        """Build the datasets and data loaders.

        Args:
            batch_size: batch size used by both loaders.
            root: directory the datasets are stored in / downloaded to.
        """
        self.batch_size = batch_size
        self.root = root

        self._train_dataset = datasets.MNIST(
            root=root, train=True, transform=transforms.ToTensor(), download=True)
        self._train_loader = DataLoader(
            dataset=self._train_dataset, batch_size=batch_size, shuffle=True)

        self._test_dataset = datasets.MNIST(
            root=root, train=False, transform=transforms.ToTensor(), download=True)
        # Evaluation does not depend on sample order, so keep it fixed and
        # make repeated runs iterate the test set identically.
        self._test_loader = DataLoader(
            dataset=self._test_dataset, batch_size=batch_size, shuffle=False)

    def train_loader(self):
        """Return the DataLoader over the training split."""
        return self._train_loader

    def test_loader(self):
        """Return the DataLoader over the test split."""
        return self._test_loader

In [None]:
# Hyperparameters for this run.
learning_rate = 0.001
batch_size = 2
epochs = 2

model = TimmModel('vit_tiny_patch16_224')
model.model_config(learning_rate=learning_rate, batch_size=batch_size, num_epochs=epochs)

# Build the loaders with the same batch size that is handed to the model
# config, instead of silently relying on MNIST's own default.
mnist = MNIST(batch_size=batch_size)

train_loader = mnist.train_loader()
test_loader = mnist.test_loader()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [None]:
# Run the training loop on the MNIST training loader; the loss is printed and
# logged to W&B every 25 batches.
model.train_model(train_loader=train_loader)

[34m[1mwandb[0m: Currently logged in as: [33mkarine_ayrs[0m (use `wandb login --relogin` to force relogin)


[1;30;43mВыходные данные были обрезаны до нескольких последних строк (5000).[0m
Loss after 50250 examples: 0.025
Loss after 50300 examples: 0.010
Loss after 50350 examples: 0.828
Loss after 50400 examples: 0.080
Loss after 50450 examples: 1.529
Loss after 50500 examples: 3.604
Loss after 50550 examples: 0.060
Loss after 50600 examples: 3.409
Loss after 50650 examples: 0.056
Loss after 50700 examples: 0.035
Loss after 50750 examples: 1.220
Loss after 50800 examples: 1.069
Loss after 50850 examples: 0.216
Loss after 50900 examples: 0.477
Loss after 50950 examples: 0.171
Loss after 51000 examples: 0.302
Loss after 51050 examples: 0.039
Loss after 51100 examples: 1.819
Loss after 51150 examples: 0.178
Loss after 51200 examples: 2.186
Loss after 51250 examples: 1.175
Loss after 51300 examples: 0.020
Loss after 51350 examples: 0.283
Loss after 51400 examples: 0.062
Loss after 51450 examples: 0.014
Loss after 51500 examples: 0.471
Loss after 51550 examples: 0.361
Loss after 51600 examples: 

In [None]:
# Report accuracy on the training split first, then on the test split.
model.check_accuracy(train_loader, test_loader)

Checking accuracy on train data
Got 51136/60000 with accuracy 85.22666666666666


  assert H == self.img_size[0] and W == self.img_size[1], \
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  /pytorch/aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)


Checking accuracy on test data
Got 8585/10000 with accuracy 85.85000000000001
