<a href="https://colab.research.google.com/github/DayDreamChaser/torch-action/blob/main/pytorch_DDP_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES
# TO THE CORRECT LOCATION (/kaggle/input) IN YOUR NOTEBOOK,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

import os
import sys
from tempfile import NamedTemporaryFile
from urllib.request import urlopen
from urllib.parse import unquote, urlparse
from urllib.error import HTTPError
from zipfile import ZipFile
import tarfile
import shutil

CHUNK_SIZE = 40960
DATA_SOURCE_MAPPING = 'pydata:https%3A%2F%2Fstorage.googleapis.com%2Fkaggle-data-sets%2F4530293%2F7749135%2Fbundle%2Farchive.zip%3FX-Goog-Algorithm%3DGOOG4-RSA-SHA256%26X-Goog-Credential%3Dgcp-kaggle-com%2540kaggle-161607.iam.gserviceaccount.com%252F20240303%252Fauto%252Fstorage%252Fgoog4_request%26X-Goog-Date%3D20240303T092102Z%26X-Goog-Expires%3D259200%26X-Goog-SignedHeaders%3Dhost%26X-Goog-Signature%3D84f4378d65f7fc8d0dd430c421206cc61d8354316f0e551722866bec1a6db208de23539ad5305b50893112aab213e18a127f904bad0a4a1c138b4a19e72c25a6683ec9e4f142d3f70d5fed041f0c8c7d60419249402070fc53ac0132f8f3c1e4bb2e0c80ad41e95b29281fc750d03750bffbcbc66efa74bb229856eef66948c93b1132cd083836d809c5f3941d7dfe613f23de444f487c6177eb56a19ab97b2a56764c46ff38d742e455cb585b9e590590d66371df782cf7af993ee31bca94baad38483cc1c14504c97803c8ba15c96d242884814408a8e5b984ede807cefdcb058c04469cbb6ea79a218242ea891c3b68d349ceffd3f1202bd52f9a6bd6cb1c'

KAGGLE_INPUT_PATH='/kaggle/input'
KAGGLE_WORKING_PATH='/kaggle/working'
KAGGLE_SYMLINK='kaggle'

!umount /kaggle/input/ 2> /dev/null
shutil.rmtree('/kaggle/input', ignore_errors=True)
os.makedirs(KAGGLE_INPUT_PATH, 0o777, exist_ok=True)
os.makedirs(KAGGLE_WORKING_PATH, 0o777, exist_ok=True)

try:
  os.symlink(KAGGLE_INPUT_PATH, os.path.join("..", 'input'), target_is_directory=True)
except FileExistsError:
  pass
try:
  os.symlink(KAGGLE_WORKING_PATH, os.path.join("..", 'working'), target_is_directory=True)
except FileExistsError:
  pass

# Download every "<directory>:<url>" entry into KAGGLE_INPUT_PATH/<directory>
# and unpack it (zip by extension, otherwise any format tarfile can open).
# Failures are reported and skipped so the remaining sources still load.
for data_source_mapping in DATA_SOURCE_MAPPING.split(','):
    # Split on the first colon only; the URL part is percent-encoded.
    directory, download_url_encoded = data_source_mapping.split(':', 1)
    download_url = unquote(download_url_encoded)
    filename = urlparse(download_url).path
    destination_path = os.path.join(KAGGLE_INPUT_PATH, directory)
    try:
        with urlopen(download_url) as fileres, NamedTemporaryFile() as tfile:
            # content-length is optional in HTTP; without it we still download,
            # we just cannot draw a proportional progress bar.
            header_length = fileres.headers['content-length']
            total_length = int(header_length) if header_length else None
            print(f'Downloading {directory}, {header_length} bytes compressed')
            dl = 0
            data = fileres.read(CHUNK_SIZE)
            while len(data) > 0:
                dl += len(data)
                tfile.write(data)
                if total_length:
                    done = int(50 * dl / total_length)
                    sys.stdout.write(f"\r[{'=' * done}{' ' * (50-done)}] {dl} bytes downloaded")
                else:
                    sys.stdout.write(f"\r{dl} bytes downloaded")
                sys.stdout.flush()
                data = fileres.read(CHUNK_SIZE)
            # Push any buffered bytes to the file and rewind so the archive
            # readers below see the complete download.
            tfile.flush()
            tfile.seek(0)
            if filename.endswith('.zip'):
                with ZipFile(tfile) as zfile:
                    zfile.extractall(destination_path)
            else:
                # Read from the open handle and bind to a local name that does
                # NOT shadow the `tarfile` module for the rest of the kernel.
                with tarfile.open(fileobj=tfile) as archive:
                    # The archive comes from the network: use the "data"
                    # extraction filter (rejects absolute paths / traversal)
                    # where this Python version provides it.
                    if hasattr(tarfile, 'data_filter'):
                        archive.extractall(destination_path, filter='data')
                    else:
                        archive.extractall(destination_path)
            print(f'\nDownloaded and uncompressed: {directory}')
    except HTTPError as e:
        print(f'Failed to load (likely expired) {download_url} to path {destination_path}')
        print(f'  reason: {e}')
    except OSError as e:
        print(f'Failed to load {download_url} to path {destination_path}')
        print(f'  reason: {e}')

print('Data source import complete.')


In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

# FashionMNIST train/test splits, downloaded into ./data on first use and
# converted to tensors with ToTensor.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

# Hyperparameters for the training loop.
learning_rate = 1e-3
batch_size = 64
epochs = 3

# Reshuffle the training set every epoch so mini-batches are not always drawn
# in the dataset's stored order; evaluation order does not affect the metrics.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

Using cuda device
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:01<00:00, 17077834.30it/s]


Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 271325.90it/s]


Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:00<00:00, 4994955.96it/s]


Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 10678673.09it/s]

Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw






In [None]:
class NeuralNetwork(nn.Module):
    """MLP classifier: 28x28 input -> 512 -> 256 -> 10 logits.

    The input is flattened to 784 features and passed through two ReLU hidden
    layers. The output is the raw (unnormalised) class scores.

    Args:
        dropout: dropout probability applied to the last hidden layer
            (default 0.2, the rate the original model used).
    """

    def __init__(self, dropout: float = 0.2):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            # Regularise the hidden representation rather than the output:
            # dropout on the logits would randomly zero class scores and
            # distort the softmax that CrossEntropyLoss applies to them.
            nn.Dropout(p=dropout),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        """Return logits of shape (batch, 10) for a batch of 28x28 inputs."""
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

# Build the MLP, then place its parameters on the selected device.
model = NeuralNetwork()
model = model.to(device)

def train_loop(dataloader, model, loss_fn, optimizer, log_every=100):
    """Run one training epoch over ``dataloader``, updating ``model`` in place.

    Each batch is moved to the device holding the model's parameters, so the
    loop does not depend on a notebook-global ``device`` variable.

    Args:
        dataloader: yields ``(X, y)`` batches; ``len(dataloader.dataset)`` is
            used for the progress denominator.
        model: module to train; it is switched to train mode.
        loss_fn: callable ``(pred, y) -> scalar tensor``.
        optimizer: optimizer over ``model``'s parameters.
        log_every: print the current batch loss every this many batches.
    """
    size = len(dataloader.dataset)
    model_device = next(model.parameters()).device
    model.train()
    seen = 0  # samples processed so far; correct even for a short final batch
    for batch, (X, y) in enumerate(dataloader, start=1):
        X, y = X.to(model_device), y.to(model_device)
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        seen += len(X)
        if batch % log_every == 0:
            print(f"loss: {loss.item():>7f}  [{seen:>5d}/{size:>5d}]")

def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    model.eval()
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(dim=-1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

# Train for `epochs` passes, evaluating on the test split after each one.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 0.992364  [ 6400/60000]
loss: 0.897658  [12800/60000]
loss: 0.590443  [19200/60000]
loss: 0.959099  [25600/60000]
loss: 0.425988  [32000/60000]
loss: 0.526136  [38400/60000]
loss: 0.861878  [44800/60000]
loss: 0.628830  [51200/60000]
loss: 0.629240  [57600/60000]
Test Error: 
 Accuracy: 83.6%, Avg loss: 0.454017 

Epoch 2
-------------------------------
loss: 0.581930  [ 6400/60000]
loss: 0.594764  [12800/60000]
loss: 0.632375  [19200/60000]
loss: 0.781220  [25600/60000]
loss: 0.426663  [32000/60000]
loss: 0.464518  [38400/60000]
loss: 0.947383  [44800/60000]
loss: 0.528568  [51200/60000]
loss: 0.494135  [57600/60000]
Test Error: 
 Accuracy: 85.2%, Avg loss: 0.410554 

Epoch 3
-------------------------------
loss: 0.610657  [ 6400/60000]
loss: 0.697756  [12800/60000]
loss: 0.468149  [19200/60000]
loss: 0.720574  [25600/60000]
loss: 0.367921  [32000/60000]
loss: 0.493198  [38400/60000]
loss: 0.697919  [44800/60000]
loss: 0.519677  [51200/600

In [None]:
import os,PIL
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch
from torch import nn

import torchvision
from torchvision import transforms
import datetime

#======================================================================
# import accelerate
from accelerate import Accelerator
from accelerate.utils import set_seed
#======================================================================


def create_dataloaders(batch_size=64):
    """Build the MNIST training and validation DataLoaders.

    Both splits are downloaded into ``./minist/`` when missing and converted
    with ``ToTensor``. Only the training loader shuffles. Both loaders use two
    worker processes and ``drop_last=True``, so a trailing partial batch is
    never yielded (including on the validation split).

    Returns:
        ``(dl_train, dl_val)``.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    loaders = []
    for is_train in (True, False):
        split = torchvision.datasets.MNIST(root="./minist/", train=is_train,
                                           download=True, transform=to_tensor)
        loaders.append(torch.utils.data.DataLoader(split,
                                                   batch_size=batch_size,
                                                   shuffle=is_train,
                                                   num_workers=2,
                                                   drop_last=True))
    train_loader, val_loader = loaders
    return train_loader, val_loader


def create_net():
    """Build the small MNIST CNN as a named ``nn.Sequential``.

    Two conv+max-pool stages (1->512->256 channels) are followed by channel
    dropout and global max-pooling to a 256-vector. A 256->128->10 MLP head
    then produces the class logits. Layer names are kept stable because
    ``state_dict`` keys are derived from them.
    """
    named_layers = [
        ("conv1", nn.Conv2d(in_channels=1, out_channels=512, kernel_size=3)),
        ("pool1", nn.MaxPool2d(kernel_size=2, stride=2)),
        ("conv2", nn.Conv2d(in_channels=512, out_channels=256, kernel_size=5)),
        ("pool2", nn.MaxPool2d(kernel_size=2, stride=2)),
        ("dropout", nn.Dropout2d(p=0.1)),
        ("adaptive_pool", nn.AdaptiveMaxPool2d((1, 1))),
        ("flatten", nn.Flatten()),
        ("linear1", nn.Linear(256, 128)),
        ("relu", nn.ReLU()),
        ("linear2", nn.Linear(128, 10)),
    ]
    net = nn.Sequential()
    for layer_name, layer in named_layers:
        net.add_module(layer_name, layer)
    return net



def training_loop(epochs=5,
                  lr=1e-3,
                  batch_size=1024,
                  ckpt_path="checkpoint.pt",
                  mixed_precision="no",  # or 'fp16'
                  ):
    """Train the MNIST CNN with Accelerate and checkpoint it after every epoch.

    After each epoch, validation accuracy is computed with predictions and
    labels gathered across processes. The unwrapped model's ``state_dict`` is
    then saved to ``f"{ckpt_path}_{epoch}"``.

    Args:
        epochs: number of passes over the training set.
        lr: base learning rate; OneCycleLR peaks at ``25 * lr``.
        batch_size: batch size passed to ``create_dataloaders``.
        ckpt_path: checkpoint filename prefix; ``_<epoch>`` is appended.
        mixed_precision: forwarded to ``Accelerator`` (e.g. "no" or "fp16").
    """
    # Seed BEFORE any module is constructed so weight initialisation is
    # reproducible too, not only the RNG draws made later in training.
    set_seed(42)

    train_dataloader, eval_dataloader = create_dataloaders(batch_size)
    model = create_net()
    loss_fn = nn.CrossEntropyLoss()  # stateless; build once, not per step

    optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=optimizer, max_lr=25*lr,
                              epochs=epochs, steps_per_epoch=len(train_dataloader))

    #======================================================================
    # initialize accelerator and auto move data/model to accelerator.device
    accelerator = Accelerator(mixed_precision=mixed_precision)
    accelerator.print(f'device {str(accelerator.device)} is used!')
    model, optimizer, lr_scheduler, train_dataloader, eval_dataloader = accelerator.prepare(
        model, optimizer, lr_scheduler, train_dataloader, eval_dataloader)
    #======================================================================

    for epoch in range(epochs):
        model.train()
        for features, labels in train_dataloader:
            preds = model(features)
            loss = loss_fn(preds, labels)
            # Use accelerator.backward instead of loss.backward() so
            # Accelerate can apply its mixed-precision / distributed handling.
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        model.eval()
        accurate = 0
        num_elems = 0
        for features, labels in eval_dataloader:
            with torch.no_grad():
                preds = model(features)
            predictions = preds.argmax(dim=-1)

            # Gather predictions and labels from all processes (DDP mode).
            # NOTE(review): the eval loader is built with drop_last=True, so a
            # trailing partial batch is never scored. If that is changed,
            # consider accelerator.gather_for_metrics to avoid counting padded
            # duplicate samples.
            predictions = accelerator.gather(predictions)
            labels = accelerator.gather(labels)

            accurate_preds = predictions == labels
            num_elems += accurate_preds.shape[0]
            accurate += accurate_preds.long().sum()

        # int() accepts both the 0-d tensor and the initial int 0, so an empty
        # eval loader no longer crashes on `.item()`.
        eval_metric = int(accurate) / num_elems if num_elems else float("nan")

        #======================================================================
        # print logs and save ckpt
        accelerator.wait_for_everyone()
        nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        accelerator.print(f"epoch【{epoch}】@{nowtime} --> eval_metric= {100 * eval_metric:.2f}%")
        unwrapped_net = accelerator.unwrap_model(model)
        accelerator.save(unwrapped_net.state_dict(), ckpt_path + "_" + str(epoch))
        #======================================================================
        #======================================================================

# With lr=1e-3, OneCycleLR peaks at 25*lr = 0.025, and a run with that setting
# stayed at ~11% (chance-level) accuracy. lr=1e-4 trains normally.
training_loop(epochs = 5,lr = 1e-4,batch_size= 1024,ckpt_path = "checkpoint.pt",
            mixed_precision="no") #mixed_precision='fp16'


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./minist/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 115489370.00it/s]


Extracting ./minist/MNIST/raw/train-images-idx3-ubyte.gz to ./minist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./minist/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 45149345.44it/s]

Extracting ./minist/MNIST/raw/train-labels-idx1-ubyte.gz to ./minist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./minist/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 28517728.59it/s]


Extracting ./minist/MNIST/raw/t10k-images-idx3-ubyte.gz to ./minist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./minist/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 8172685.01it/s]


Extracting ./minist/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./minist/MNIST/raw

device cuda is used!
epoch【0】@2024-03-03 09:08:10 --> eval_metric= 10.18%
epoch【1】@2024-03-03 09:08:24 --> eval_metric= 11.30%
epoch【2】@2024-03-03 09:08:38 --> eval_metric= 11.30%
epoch【3】@2024-03-03 09:08:52 --> eval_metric= 11.30%
epoch【4】@2024-03-03 09:09:06 --> eval_metric= 11.30%


In [None]:
# Build a fresh CNN and print its module tree.
net = create_net()
print(str(net))

Sequential(
  (conv1): Conv2d(1, 512, kernel_size=(3, 3), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(512, 256, kernel_size=(5, 5), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout): Dropout2d(p=0.1, inplace=False)
  (adaptive_pool): AdaptiveMaxPool2d(output_size=(1, 1))
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=256, out_features=128, bias=True)
  (relu): ReLU()
  (linear2): Linear(in_features=128, out_features=10, bias=True)
)


In [None]:
# Small-batch loaders for interactively inspecting the data.
dl_train, dl_val = create_dataloaders(64)

In [None]:
# Summarise one training example instead of dumping the whole image tensor.
sample_image, sample_label = dl_train.dataset[0]
print(f"image: shape={tuple(sample_image.shape)}, dtype={sample_image.dtype}, "
      f"range=[{sample_image.min().item():.3f}, {sample_image.max().item():.3f}]; "
      f"label: {sample_label}")

(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000

NameError: name 'unloader' is not defined

In [None]:
import os
from accelerate.utils import write_basic_config
# Write a basic Accelerate config file non-interactively (the interactive
# alternative is `accelerate config`).
write_basic_config() # Write a config file
# Hard-exit the kernel process so the notebook restarts and later cells read
# the freshly written config. Everything defined before this point is lost,
# which is presumably why the training code is redefined in a later cell.
os._exit(0) # Restart the notebook to reload info from the latest config file


In [None]:
# Alternatively, answer a few interactive questions to create a config:
#!accelerate config

In [None]:
import os,PIL
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch
from torch import nn

import torchvision
from torchvision import transforms
import datetime

#======================================================================
# import accelerate
from accelerate import Accelerator
from accelerate.utils import set_seed
#======================================================================


def create_dataloaders(batch_size=64):
    """Build the MNIST training and validation DataLoaders.

    Both splits are downloaded into ``./minist/`` when missing and converted
    with ``ToTensor``. Only the training loader shuffles. Both loaders use two
    worker processes and ``drop_last=True``, so a trailing partial batch is
    never yielded (including on the validation split).

    Returns:
        ``(dl_train, dl_val)``.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    loaders = []
    for is_train in (True, False):
        split = torchvision.datasets.MNIST(root="./minist/", train=is_train,
                                           download=True, transform=to_tensor)
        loaders.append(torch.utils.data.DataLoader(split,
                                                   batch_size=batch_size,
                                                   shuffle=is_train,
                                                   num_workers=2,
                                                   drop_last=True))
    train_loader, val_loader = loaders
    return train_loader, val_loader


def create_net():
    """Build the small MNIST CNN as a named ``nn.Sequential``.

    Two conv+max-pool stages (1->512->256 channels) are followed by channel
    dropout and global max-pooling to a 256-vector. A 256->128->10 MLP head
    then produces the class logits. Layer names are kept stable because
    ``state_dict`` keys are derived from them.
    """
    named_layers = [
        ("conv1", nn.Conv2d(in_channels=1, out_channels=512, kernel_size=3)),
        ("pool1", nn.MaxPool2d(kernel_size=2, stride=2)),
        ("conv2", nn.Conv2d(in_channels=512, out_channels=256, kernel_size=5)),
        ("pool2", nn.MaxPool2d(kernel_size=2, stride=2)),
        ("dropout", nn.Dropout2d(p=0.1)),
        ("adaptive_pool", nn.AdaptiveMaxPool2d((1, 1))),
        ("flatten", nn.Flatten()),
        ("linear1", nn.Linear(256, 128)),
        ("relu", nn.ReLU()),
        ("linear2", nn.Linear(128, 10)),
    ]
    net = nn.Sequential()
    for layer_name, layer in named_layers:
        net.add_module(layer_name, layer)
    return net



def training_loop(epochs=5,
                  lr=1e-3,
                  batch_size=1024,
                  ckpt_path="checkpoint.pt",
                  mixed_precision="no",  # or 'fp16'
                  ):
    """Train the MNIST CNN with Accelerate and checkpoint it after every epoch.

    After each epoch, validation accuracy is computed with predictions and
    labels gathered across processes. The unwrapped model's ``state_dict`` is
    then saved to ``f"{ckpt_path}_{epoch}"``.

    Args:
        epochs: number of passes over the training set.
        lr: base learning rate; OneCycleLR peaks at ``25 * lr``.
        batch_size: batch size passed to ``create_dataloaders``.
        ckpt_path: checkpoint filename prefix; ``_<epoch>`` is appended.
        mixed_precision: forwarded to ``Accelerator`` (e.g. "no" or "fp16").
    """
    # Seed BEFORE any module is constructed so weight initialisation is
    # reproducible too, not only the RNG draws made later in training.
    set_seed(42)

    train_dataloader, eval_dataloader = create_dataloaders(batch_size)
    model = create_net()
    loss_fn = nn.CrossEntropyLoss()  # stateless; build once, not per step

    optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=optimizer, max_lr=25*lr,
                              epochs=epochs, steps_per_epoch=len(train_dataloader))

    #======================================================================
    # initialize accelerator and auto move data/model to accelerator.device
    accelerator = Accelerator(mixed_precision=mixed_precision)
    accelerator.print(f'device {str(accelerator.device)} is used!')
    model, optimizer, lr_scheduler, train_dataloader, eval_dataloader = accelerator.prepare(
        model, optimizer, lr_scheduler, train_dataloader, eval_dataloader)
    #======================================================================

    for epoch in range(epochs):
        model.train()
        for features, labels in train_dataloader:
            preds = model(features)
            loss = loss_fn(preds, labels)
            # Use accelerator.backward instead of loss.backward() so
            # Accelerate can apply its mixed-precision / distributed handling.
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        model.eval()
        accurate = 0
        num_elems = 0
        for features, labels in eval_dataloader:
            with torch.no_grad():
                preds = model(features)
            predictions = preds.argmax(dim=-1)

            # Gather predictions and labels from all processes (DDP mode).
            # NOTE(review): the eval loader is built with drop_last=True, so a
            # trailing partial batch is never scored. If that is changed,
            # consider accelerator.gather_for_metrics to avoid counting padded
            # duplicate samples.
            predictions = accelerator.gather(predictions)
            labels = accelerator.gather(labels)

            accurate_preds = predictions == labels
            num_elems += accurate_preds.shape[0]
            accurate += accurate_preds.long().sum()

        # int() accepts both the 0-d tensor and the initial int 0, so an empty
        # eval loader no longer crashes on `.item()`.
        eval_metric = int(accurate) / num_elems if num_elems else float("nan")

        #======================================================================
        # print logs and save ckpt
        accelerator.wait_for_everyone()
        nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        accelerator.print(f"epoch【{epoch}】@{nowtime} --> eval_metric= {100 * eval_metric:.2f}%")
        unwrapped_net = accelerator.unwrap_model(model)
        accelerator.save(unwrapped_net.state_dict(), ckpt_path + "_" + str(epoch))
        #======================================================================
        #======================================================================

# training_loop(epochs = 5,lr = 1e-4,batch_size= 1024,ckpt_path = "checkpoint.pt",
#             mixed_precision="no") #mixed_precision='fp16'

In [None]:
from accelerate import notebook_launcher
#args = (5,1e-4,1024,'checkpoint.pt','no')

# Positional arguments for training_loop, in its parameter order
# (epochs, lr, batch_size, ckpt_path, mixed_precision). The dict keeps the
# names readable; tuple() turns the dict_values view into a plain, picklable
# sequence, which is what the launcher's `args` expects.
args = tuple(dict(epochs = 5,
        lr = 1e-4,
        batch_size= 1024,
        ckpt_path = "checkpoint.pt",
        mixed_precision="no").values())
notebook_launcher(training_loop, args, num_processes=2)

Launching training on 2 GPUs.
device cuda:0 is used!
epoch【0】@2024-03-03 09:11:41 --> eval_metric= 90.37%
epoch【1】@2024-03-03 09:11:49 --> eval_metric= 97.27%
epoch【2】@2024-03-03 09:11:56 --> eval_metric= 98.10%
epoch【3】@2024-03-03 09:12:04 --> eval_metric= 98.33%
epoch【4】@2024-03-03 09:12:12 --> eval_metric= 98.43%
