In [1]:
!pip install pyngrok loguru safetensors

Collecting pyngrok
  Using cached pyngrok-7.0.0-py3-none-any.whl
Collecting loguru
  Obtaining dependency information for loguru from https://files.pythonhosted.org/packages/03/0a/4f6fed21aa246c6b49b561ca55facacc2a44b87d65b8b92362a8e99ba202/loguru-0.7.2-py3-none-any.whl.metadata
  Using cached loguru-0.7.2-py3-none-any.whl.metadata (23 kB)
Using cached loguru-0.7.2-py3-none-any.whl (62 kB)
Installing collected packages: pyngrok, loguru
Successfully installed loguru-0.7.2 pyngrok-7.0.0


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision


class Block(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(n_in, n_out, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(n_in, n_out, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(n_in, n_out, 3, padding=1),
        )
        self.skip = (
            nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
        )
        self.fuse = nn.ReLU()

    def forward(self, x):
        return self.fuse(self.conv(x) + self.skip(x))


class DownSample(nn.Module):
    def __init__(self, n_in, n_out, n_hidden):
        super().__init__()
        self.conv1 = nn.Conv2d(n_in, n_hidden, 3, padding=1, bias=False)
        self.downconv = nn.Conv2d(n_hidden, n_hidden, 3, stride=2, bias=False)
        self.conv2 = nn.Conv2d(n_hidden, n_out, 3, padding=1, bias=False)
        self.fuse = nn.ReLU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.downconv(x)
        x = self.conv2(x)
        return self.fuse(x)


class Net(nn.Module):
    def __init__(
        self, n_in=1, n_classes=10, depths=[3, 3, 9, 3], dims=[64, 64, 64, 64]
    ):
        super().__init__()
        self.stem = nn.Conv2d(n_in, dims[0], 3, padding=1)
        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], n_classes)
        self.downsample_layers = nn.ModuleList()
        self.stages = nn.ModuleList()
        for i in range(3):
            self.downsample_layers.append(DownSample(dims[i], dims[i + 1], 2 * dims[i]))
        for i in range(4):
            stage = nn.Sequential(*[Block(dims[i], dims[i]) for _ in range(depths[i])])
            self.stages.append(stage)

    def forward(self, x):
        x = self.stem(x)
        for i in range(3):
            x = self.stages[i](x)
            x = self.downsample_layers[i](x)
        x = self.stages[-1](x)
        x = self.norm(x.mean([-2, -1]))
        x = self.head(x)
        return F.softmax(x)

In [3]:
from typing import Any, Dict

import matplotlib.pyplot as plt
import torch
from accelerate.tracking import on_main_process
from loguru import logger
from torch.utils.tensorboard import SummaryWriter


class MixLogger:
    main_process_only = True
    @on_main_process
    def __init__(self, log_dir):
        self.writer = SummaryWriter(log_dir)
        self.logger = logger
    
    @on_main_process
    def _info(self, info):
        self.logger.info(info)

    @on_main_process
    def _log(self, tag, scalar_value, step: int = 0):
        self.writer.add_scalar(tag, scalar_value, global_step=step)

    @on_main_process
    def _log_graph(self, model, input_to_model, verbose=False, use_strict_trace=True):
        self._info("Create model graph")
        self.writer.add_graph(model, input_to_model, verbose=verbose, use_strict_trace=use_strict_trace)



In [5]:
%load_ext tensorboard
%tensorboard --logdir /kaggle/working/

import os
import torch.nn as nn
import torch.optim

from accelerate import Accelerator, notebook_launcher
from safetensors.torch import save_model

from pyngrok import conf, ngrok
ngrokToken = ""
conf.get_default().auth_token = ngrokToken
conf.get_default().monitor_thread = False
ssh_tunnels = ngrok.get_tunnels(conf.get_default())
if len(ssh_tunnels) == 0:
    ssh_tunnel = ngrok.connect(6006)
    print('address：'+ssh_tunnel.public_url)
else:
    print('address：'+ssh_tunnels[0].public_url)

def main(name, learning_rate, batch_size, epochs, mixed_precision):
    train_dataloader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            "./data/",
            train=True,
            download=True,
            transform=torchvision.transforms.Compose(
                [
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
                ]
            ),
        ),
        batch_size=batch_size,
        shuffle=True,
    )
    test_dataloader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            "./data/",
            train=False,
            download=True,
            transform=torchvision.transforms.Compose(
                [
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
                ]
            ),
        ),
        batch_size=batch_size,
        shuffle=True,
    )

    accelerator = Accelerator(mixed_precision=mixed_precision)
    
    model = Net()
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    
    model, optimizer, train_dataloader, test_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, test_dataloader
    )
    
    total_batch_size = batch_size * accelerator.num_processes
    
    globe_step = 0
    accelerator.print("***** Running training *****")
    accelerator.print(f"  Num epochs = {epochs}")
    accelerator.print(f"  Num batches each epoch = {len(train_dataloader)}")
    accelerator.print(f"  Num Steps = {epochs*len(train_dataloader)}")
    accelerator.print(f"  Instantaneous batch size per device = {batch_size}")
    accelerator.print(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    
    logger = MixLogger(name)

    data, target = next(iter(train_dataloader))
    logger._log_graph(model, data)
    
    for epoch in range(epochs):
        # train
        logger._info(f"Epoch: [{epoch+1}/{epochs}]")
        model.train()
        for step, batch in enumerate(train_dataloader):
            data, targets = batch
            
            optimizer.zero_grad()
            outputs = model(data)
            loss = loss_fn(outputs, targets)
            accelerator.backward(loss)
            optimizer.step()

            lr = optimizer.param_groups[0]['lr']
            loss, current = loss.mean(), step * total_batch_size

            logger._info(f"loss: {loss:>7f}  [{current:>5d}/{len(train_dataloader.dataset):>5d}]")

            logger._log("loss/train", loss, globe_step)
            logger._log("lr/lr", lr, globe_step)
            globe_step += 1

        accelerator.wait_for_everyone()

        # test
        model.eval()
        test_loss = 0
        correct = 0
        for i, batch in enumerate(test_dataloader):
            with torch.no_grad():
                data, targets = batch
                outputs = model(data)
            correct += outputs.argmax(1).eq(targets).sum()
            test_loss += loss_fn(outputs, targets).mean()
            

        test_loss /= len(test_dataloader)
        test_acc = correct / (len(test_dataloader)*batch_size)
        logger._info(f"[{epoch+1}/{epochs}] Test Avg loss: {test_loss:>8f} Test Avg Acc: {test_acc:>8f}")
        logger._log("loss/test", test_loss, epoch)
        logger._log("acc/test", test_acc, epoch)
        
        accelerator.wait_for_everyone()

        # save model
        if accelerator.is_main_process:
            save_model(model, f"{name}-{str(epoch+1).zfill(3)}.safetensors")
            accelerator.save_state("Training_state")
    
    if accelerator.is_main_process:
        save_model(model, f"{name}.safetensors")
        accelerator.save_state("Training_state")

args = ("Net", 1e-3, 5120, 25, "fp16")
notebook_launcher(main, args, num_processes=2)

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


address：https://f0e5-34-69-126-24.ngrok-free.app
Launching training on 2 GPUs.
***** Running training *****
  Num epochs = 25
  Num batches each epoch = 6
  Num Steps = 150
  Instantaneous batch size per device = 5120
  Total train batch size (w. parallel, distributed & accumulation) = 10240


[32m2023-11-13 06:21:35.170[0m | [1mINFO    [0m | [36m__main__[0m:[36m_info[0m:[36m19[0m - [1mCreate model graph[0m
  return F.softmax(x)
  return F.softmax(x)
[32m2023-11-13 06:21:40.573[0m | [1mINFO    [0m | [36m__main__[0m:[36m_info[0m:[36m19[0m - [1mEpoch: [1/25][0m
[32m2023-11-13 06:21:44.668[0m | [1mINFO    [0m | [36m__main__[0m:[36m_info[0m:[36m19[0m - [1mloss: 2.304173  [    0/60000][0m
[32m2023-11-13 06:21:47.343[0m | [1mINFO    [0m | [36m__main__[0m:[36m_info[0m:[36m19[0m - [1mloss: 2.303341  [10240/60000][0m
[32m2023-11-13 06:21:50.785[0m | [1mINFO    [0m | [36m__main__[0m:[36m_info[0m:[36m19[0m - [1mloss: 2.301650  [20480/60000][0m
[32m2023-11-13 06:21:53.469[0m | [1mINFO    [0m | [36m__main__[0m:[36m_info[0m:[36m19[0m - [1mloss: 2.297041  [30720/60000][0m
[32m2023-11-13 06:21:56.198[0m | [1mINFO    [0m | [36m__main__[0m:[36m_info[0m:[36m19[0m - [1mloss: 2.263313  [40960/60000][0m
[32m2023-