In [1]:
!pip install pyngrok loguru safetensors

Collecting pyngrok
  Downloading pyngrok-7.0.0.tar.gz (718 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m718.7/718.7 kB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l- \ done
[?25hCollecting loguru
  Obtaining dependency information for loguru from https://files.pythonhosted.org/packages/03/0a/4f6fed21aa246c6b49b561ca55facacc2a44b87d65b8b92362a8e99ba202/loguru-0.7.2-py3-none-any.whl.metadata
  Downloading loguru-0.7.2-py3-none-any.whl.metadata (23 kB)
Downloading loguru-0.7.2-py3-none-any.whl (62 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.5/62.5 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: pyngrok
  Building wheel for pyngrok (setup.py) ... [?25l- \ done
[?25h  Created wheel for pyngrok: filename=pyngrok-7.0.0-py3-none-any.whl size=21129 sha256=40cb48ff52bba720710087cad96a8078bff6b1ee2e5be4259575651312148740
  Stored in direct

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision


class Block(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(n_in, n_out, 3, padding=1),
            nn.BatchNorm2d(n_out),
            nn.ReLU(),
            nn.Conv2d(n_in, n_out, 3, padding=1),
            nn.BatchNorm2d(n_out)
        )
        self.skip = (
            nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
        )
        self.fuse = nn.ReLU()

    def forward(self, x):
        return self.fuse(self.conv(x) + self.skip(x))


class DownSample(nn.Module):
    def __init__(self, n_in, n_out, n_hidden):
        super().__init__()
        self.downconv = nn.Conv2d(n_in, n_out, 3, stride=2, bias=False)
        self.bn = nn.BatchNorm2d(n_out)
        self.fuse = nn.ReLU()

    def forward(self, x):
        x = self.downconv(x)
        x = self.bn(x)
        return self.fuse(x)


class Net(nn.Module):
    def __init__(
        self, n_in=3, n_classes=10, depths=[3, 3, 9, 3], dims=[32, 64, 128, 64]
    ):
        super().__init__()
        self.stem = nn.Conv2d(n_in, dims[0], 3, padding=1)
        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], n_classes)
        self.downsample_layers = nn.ModuleList()
        self.stages = nn.ModuleList()
        for i in range(3):
            self.downsample_layers.append(DownSample(dims[i], dims[i + 1], 2 * dims[i]))
        for i in range(4):
            stage = nn.Sequential(*[Block(dims[i], dims[i]) for _ in range(depths[i])])
            self.stages.append(stage)

    def forward(self, x):
        x = self.stem(x)
        for i in range(3):
            x = self.stages[i](x)
            x = self.downsample_layers[i](x)
        x = self.stages[-1](x)
        x = self.norm(x.mean([-2, -1]))
        x = self.head(x)
        return F.softmax(x, dim=-1)

In [3]:
from typing import Any, Dict

import matplotlib.pyplot as plt
import torch
from accelerate.tracking import on_main_process
from loguru import logger
from torch.utils.tensorboard import SummaryWriter


class MixLogger:
    main_process_only = True
    @on_main_process
    def __init__(self, log_dir):
        self.writer = SummaryWriter(log_dir)
        self.logger = logger
    
    @on_main_process
    def _info(self, info):
        self.logger.info(info)

    @on_main_process
    def _log(self, tag, scalar_value, step: int = 0):
        self.writer.add_scalar(tag, scalar_value, global_step=step)

    @on_main_process
    def _log_graph(self, model, input_to_model, verbose=False, use_strict_trace=True):
        self._info("Create model graph")
        self.writer.add_graph(model, input_to_model, verbose=verbose, use_strict_trace=use_strict_trace)



In [4]:
torchvision.datasets.CIFAR10(
            "./data/",
            train=False,
            download=True,
            transform=torchvision.transforms.Compose(
                [
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                ]
            ),
        )

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 57670636.84it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data/


Dataset CIFAR10
    Number of datapoints: 10000
    Root location: ./data/
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
           )

In [5]:
%load_ext tensorboard
%tensorboard --logdir /kaggle/working/

import os
import torch.nn as nn
import torch.optim

from accelerate import Accelerator, notebook_launcher
from safetensors.torch import save_model

from pyngrok import conf, ngrok
ngrokToken = ""
conf.get_default().auth_token = ngrokToken
conf.get_default().monitor_thread = False
ssh_tunnels = ngrok.get_tunnels(conf.get_default())
if len(ssh_tunnels) == 0:
    ssh_tunnel = ngrok.connect(6006)
    print('address：'+ssh_tunnel.public_url)
else:
    print('address：'+ssh_tunnels[0].public_url)

def main(name, learning_rate, batch_size, epochs, mixed_precision):
    train_dataloader = torch.utils.data.DataLoader(
        torchvision.datasets.CIFAR10(
            "./data/",
            train=True,
            download=True,
            transform=torchvision.transforms.Compose(
                [
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                ]
            ),
        ),
        batch_size=batch_size,
        shuffle=True,
    )
    test_dataloader = torch.utils.data.DataLoader(
        torchvision.datasets.CIFAR10(
            "./data/",
            train=False,
            download=True,
            transform=torchvision.transforms.Compose(
                [
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                ]
            ),
        ),
        batch_size=batch_size,
        shuffle=True,
    )

    accelerator = Accelerator(mixed_precision=mixed_precision)
    
    model = Net()
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    
    model, optimizer, train_dataloader, test_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, test_dataloader
    )
    
    total_batch_size = batch_size * accelerator.num_processes
    
    globe_step = 0
    accelerator.print("***** Running training *****")
    accelerator.print(f"  Num epochs = {epochs}")
    accelerator.print(f"  Num batches each epoch = {len(train_dataloader)}")
    accelerator.print(f"  Num Steps = {epochs*len(train_dataloader)}")
    accelerator.print(f"  Instantaneous batch size per device = {batch_size}")
    accelerator.print(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    
    logger = MixLogger(name)

    
    for epoch in range(epochs):
        # train
        logger._info(f"Epoch: [{epoch+1}/{epochs}]")
        model.train()
        for step, batch in enumerate(train_dataloader):
            data, targets = batch
            
            optimizer.zero_grad()
            outputs = model(data)
            loss = loss_fn(outputs, targets)
            accelerator.backward(loss)
            optimizer.step()

            lr = optimizer.param_groups[0]['lr']
            loss, current = loss.mean(), step * total_batch_size
            train_acc = outputs.argmax(1).eq(targets).sum() / (len(batch)*batch_size) * accelerator.num_processes

            logger._info(f"loss: {loss:>7f}  Acc: {train_acc:>7f}  [{current:>5d}/{len(train_dataloader.dataset):>5d}]")

            logger._log("loss/train", loss, globe_step)
            logger._log("lr/lr", lr, globe_step)
            globe_step += 1

        accelerator.wait_for_everyone()

        # test
        model.eval()
        test_loss = 0
        test_acc = 0
        for i, batch in enumerate(test_dataloader):
            with torch.no_grad():
                data, targets = batch
                outputs = model(data)
            test_acc += outputs.argmax(1).eq(targets).sum() / (len(batch)*batch_size) * accelerator.num_processes
            test_loss += loss_fn(outputs, targets).mean()
            

        test_loss /= len(test_dataloader)
        test_acc /= len(test_dataloader)
        logger._info(f"[{epoch+1}/{epochs}] Test Avg loss: {test_loss:>8f}  Test Avg Acc: {test_acc:>8f}")
        logger._log("loss/test", test_loss, epoch)
        logger._log("acc/test", test_acc, epoch)
        
        accelerator.wait_for_everyone()

        # save model
        if accelerator.is_main_process:
            save_model(model, f"{name}-{str(epoch+1).zfill(3)}.safetensors")
            accelerator.save_state("Training_state")
    
    if accelerator.is_main_process:
        save_model(model, f"{name}.safetensors")
        accelerator.save_state("Training_state")

args = ("Net", 1e-3, 6250, 125, "fp16")
notebook_launcher(main, args, num_processes=2)

address：https://0708-34-75-40-165.ngrok-free.app
Launching training on 2 GPUs.
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
***** Running training *****
  Num epochs = 125
  Num batches each epoch = 4
  Num Steps = 500
  Instantaneous batch size per device = 6250
  Total train batch size (w. parallel, distributed & accumulation) = 12500


[32m2023-11-14 02:36:43.686[0m | [1mINFO    [0m | [36m__main__[0m:[36m_info[0m:[36m19[0m - [1mEpoch: [1/125][0m
[32m2023-11-14 02:36:53.341[0m | [1mINFO    [0m | [36m__main__[0m:[36m_info[0m:[36m19[0m - [1mloss: 2.302068  Acc: 0.106400  [    0/50000][0m
[32m2023-11-14 02:36:56.451[0m | [1mINFO    [0m | [36m__main__[0m:[36m_info[0m:[36m19[0m - [1mloss: 2.283586  Acc: 0.144000  [12500/50000][0m
[32m2023-11-14 02:36:59.500[0m | [1mINFO    [0m | [36m__main__[0m:[36m_info[0m:[36m19[0m - [1mloss: 2.237179  Acc: 0.228640  [25000/50000][0m
[32m2023-11-14 02:37:00.920[0m | [1mINFO    [0m | [36m__main__[0m:[36m_info[0m:[36m19[0m - [1mloss: 2.221783  Acc: 0.233760  [37500/50000][0m
[32m2023-11-14 02:37:02.954[0m | [1mINFO    [0m | [36m__main__[0m:[36m_info[0m:[36m19[0m - [1m[1/125] Test Avg loss: 2.306153  Test Avg Acc: 0.095360[0m
[32m2023-11-14 02:37:03.124[0m | [1mINFO    [0m | [36m__main__[0m:[36m_info[0m:[36m19