# 🌱🌿🌾Sorghum PyTorch DDP baseline🚀

This notebook is ***based on*** the [PyTorch DDP tutorial](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html).

>If you want to run DDP on **TPU**, only a few code changes are needed; see [pytorch-xla](https://github.com/pytorch/xla).

Thanks to the small JPEG Sorghum images from https://www.kaggle.com/datasets/mithilsalunkhe/small-jpegs-fgvc

If you have any ***questions*** about my baseline, please *feel free* to ***leave a comment***. I will reply as soon as possible! If you like it, please **upvote**👏👏👏


## Inference results
>##### First fold of the 5-fold models, without TTA: LB **Acc@1: 82.9%**
>##### Folds 1 and 2 averaged, without TTA: LB **Acc@1: 85.5%**
>##### Model trained on all training images (instead of k-fold), without TTA: LB **Acc@1: 85.6%**
>##### Model trained on all training images, with flip and crop TTA: LB **Acc@1: 86.3%**
>##### Folds 1, 2 and 3 of the 5-fold models plus the all-training-images model, averaged, with flip and crop TTA: LB **Acc@1: 87.7%**
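The ensemble and TTA rows above average predictions over several models and several augmented views of each test image. Below is a minimal numpy sketch of that averaging. It assumes softmax probabilities are averaged, which may differ from how the inference notebook combines scores. `make_toy_model` is a hypothetical stand-in for a trained network, and only flip views are shown, not the crop views.

```python
import numpy as np


def softmax(logits):
    # numerically stable softmax over the class axis
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def flip_views(images):
    # images: (N, C, H, W) -> original, horizontal flip, vertical flip
    return [images, images[..., ::-1], images[..., ::-1, :]]


def ensemble_tta_predict(models, images, k=5):
    """Average class probabilities over every (model, view) pair,
    then return the averaged probabilities and the top-k class indices."""
    probs = [softmax(model(view)) for model in models for view in flip_views(images)]
    mean_probs = np.mean(probs, axis=0)
    topk = np.argsort(-mean_probs, axis=1)[:, :k]
    return mean_probs, topk


def make_toy_model(weight):
    # hypothetical stand-in for a trained network: global average pool + linear layer
    return lambda x: x.mean(axis=(2, 3)) @ weight


rng = np.random.default_rng(0)
images = rng.random((2, 3, 8, 8))                                       # 2 RGB "images"
models = [make_toy_model(rng.normal(size=(3, 10))) for _ in range(3)]   # e.g. 3 folds
mean_probs, top5 = ensemble_tta_predict(models, images, k=5)
```

The toy model is exactly flip-invariant because of its global average pooling. A real CNN is only approximately invariant, and averaging over views reduces the variance of its predictions.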

## Main idea
+ *Pre-process* the images (the Visualization part of this notebook explains why)
+ Visualize and check the pre-processing
+ Use a pretrained model (via `timm`)
+ Image augmentation
+ Train the model with DDP (and mixed precision)
+ [Inference](https://www.kaggle.com/code/leoooo333/lb-0-85-sorghum-higer-accuracy)

## Tricks
#### Pre-process
+ **CLAHE**: to pre-process the images, use CLAHE in OpenCV to make them brighter and increase their contrast. [OpenCV official tutorial here](https://docs.opencv.org/4.x/d5/daf/tutorial_py_histogram_equalization.html)
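The notebook calls OpenCV's `cv.createCLAHE` for this step. To show what "contrast limited" means, here is a numpy sketch of the per-tile core only. The histogram is clipped at a limit, the clipped counts are spread over all bins, and the cumulative histogram becomes the intensity mapping. `clip_limited_equalize` is a hypothetical helper. Its `clip_limit` is defined relative to the average bin height, which is an assumption of this sketch. Full CLAHE also tiles the image and interpolates between tile mappings.

```python
import numpy as np


def clip_limited_equalize(tile, clip_limit=2.0, n_bins=256):
    """Contrast-limited histogram equalization of one uint8 tile.

    clip_limit is expressed relative to the average bin height (an assumption
    of this sketch). Full CLAHE also splits the image into a grid of tiles and
    blends neighbouring tile mappings with bilinear interpolation.
    """
    hist = np.bincount(tile.ravel(), minlength=n_bins).astype(np.float64)
    limit = clip_limit * tile.size / n_bins
    # clip tall bins and spread the clipped counts evenly over all bins
    excess = np.clip(hist - limit, 0.0, None).sum()
    hist = np.minimum(hist, limit) + excess / n_bins
    # the equalization mapping is the normalized cumulative histogram
    cdf = np.cumsum(hist)
    lut = np.round((cdf - cdf[0]) / (cdf[-1] - cdf[0]) * (n_bins - 1)).astype(np.uint8)
    return lut[tile]


rng = np.random.default_rng(0)
tile = rng.integers(100, 140, size=(64, 64)).astype(np.uint8)  # dull, low-contrast patch
out = clip_limited_equalize(tile)
```

A larger clip limit lets the mapping stretch contrast more aggressively. The notebook's `clipLimit=40` is a strong setting.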

#### Model
+ **ArcFace loss**: it slightly improves classification accuracy. See the [paper](https://arxiv.org/pdf/1801.07698.pdf) to learn more about ArcFace loss!

+ **Multiple dropout**: try it to train a model that generalizes better. See the [paper](https://arxiv.org/pdf/1905.09788.pdf) to learn about multi-sample dropout!

+ **Concat global pooling**: apply both average pooling and max pooling in the global-pool layer and concatenate the results.

+ **Pseudo-label pretraining**: use the EfficientNet Noisy Student pretrained model (`tf_efficientnet_b5_ns`), which was trained with pseudo labels and is small in size but high in accuracy.
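The ArcFace bullet corresponds to `ArcMarginProduct` in the Model section. Below is a forward-pass-only numpy sketch of the logits it produces with `easy_margin=False` and no label smoothing. `arcface_logits` is a hypothetical helper, and `s` and `m` play the roles of `S` and `M` in the config. Every logit is `s * cos(theta)`, except the true class, which receives the angular penalty `cos(theta + m)`.

```python
import numpy as np


def arcface_logits(embeddings, weight, labels, s=30.0, m=0.5):
    """ArcFace logits in numpy (forward pass only).

    embeddings: (N, D); weight: (C, D) class centres; labels: (N,) class ids.
    Mirrors the easy_margin=False branch of ArcMarginProduct, without label smoothing.
    """
    x = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    w = weight / np.linalg.norm(weight, axis=1, keepdims=True)
    cosine = np.clip(x @ w.T, -1.0, 1.0)             # cos(theta) for every class
    sine = np.sqrt(1.0 - cosine ** 2)
    phi = cosine * np.cos(m) - sine * np.sin(m)       # cos(theta + m)
    # when theta + m > pi, cos(theta + m) would rise again; use a linear penalty instead
    th, mm = np.cos(np.pi - m), np.sin(np.pi - m) * m
    phi = np.where(cosine > th, phi, cosine - mm)
    one_hot = np.eye(weight.shape[0])[labels]
    # margin only on the true class, then scale everything by s
    return s * (one_hot * phi + (1.0 - one_hot) * cosine)


rng = np.random.default_rng(0)
emb = rng.normal(size=(4, 16))
centres = rng.normal(size=(10, 16))
labels = np.array([0, 3, 3, 9])
logits = arcface_logits(emb, centres, labels, s=30.0, m=0.5)
```

Because the true class is always penalized, the network must push each embedding closer to its class centre than plain softmax would require. At test time there is no label, so predictions are usually taken from the un-margined cosine scores.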

#### Training
+ **Mixed precision**: efficient and safe! It saves GPU memory and speeds up training.

+ **Ranger**: better than AdamW most of the time.

+ **Exponential warmup**: really helps models with an attention mechanism. But be aware of **local optima**, which can result from a small lr.

+ **Cosine scheduler**: after warmup, decay the lr following a cosine curve.

+ **Big lr makes surprises**: even though we use transfer learning, a big lr (like lr=4e-3) at the right time can speed up training, especially when the loss has not changed for a long time.
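The warmup and cosine bullets correspond to the `LambdaLR` to `CosineAnnealingLR` switch inside `train_model_dist`. Below is a minimal sketch of the resulting per-epoch learning rate. It assumes the default config (`LR=1e-3`, `LR_MIN=1e-5`, `WARM_UP=5`, `EPOCH=20`) and one scheduler step per epoch. `lr_at_epoch` is a hypothetical helper, not part of the notebook.

```python
import math


def lr_at_epoch(epoch, lr=1e-3, lr_min=1e-5, warm_up=5, num_epochs=20):
    """Learning rate used during `epoch` (0-based), assuming the notebook's scheme:
    exponential warmup lr * (1/2) ** (warm_up - epoch), then cosine decay towards lr_min."""
    if epoch < warm_up:
        # each warmup epoch doubles the lr until it reaches the full value
        return lr * 0.5 ** (warm_up - epoch)
    t, t_max = epoch - warm_up, num_epochs - warm_up
    return lr_min + (lr - lr_min) * (1 + math.cos(math.pi * t / t_max)) / 2


schedule = [lr_at_epoch(e) for e in range(20)]
```

With `EPOCH = 20`, the last epoch runs at `t = 14` of 15, so the lr ends at about 2e-5 rather than exactly `LR_MIN`.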

#### Inference
+ **Inference Tutorial is now available! [click here](https://www.kaggle.com/code/leoooo333/sorghum-higer-accuracy/notebook)**

## DDP
**The most important thing**
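>To enable DDP, do not forget to convert this notebook to a .py file and run it from the command line (`mp.spawn` starts new processes that must import `train_model_dist` from a file, which generally fails inside a notebook).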
>You can also download my [.py version](https://drive.google.com/file/d/1McrTZxIxmvn8L72yzYhW1Sx36OgPB5YE/view?usp=sharing) of this baseline.

In [None]:
!pip install seaborn

In [None]:
!pip install timm

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms, models
import pandas as pd
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image as Img
from tqdm import tqdm
import os
from sklearn.metrics import accuracy_score
import timm
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR

import torchvision
import matplotlib.pyplot as plt
import re
from sklearn.model_selection import train_test_split, StratifiedKFold
import pytorch_lightning as pl
import seaborn as sns
import cv2 as cv
import numpy as np
import torch.nn.functional as F

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter

import math

# Config

In [None]:
MODEL_NAME = 'tf_efficientnet_b5_ns'
LR = 1e-3
LR_MIN = 1e-5
BATCH_SIZE = 24
IMAGE_SIZE = 900
EPOCH = 20
WARM_UP=5
WEIGHT_DECAY = 1e-3
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
FOLD = 5
LOSS = 'CrossEntropy'
OPTIM = 'adamW'
SCHEDULER = 'Cosine'
USE_AMP = True
INIT = False

'''The last two parameters depend on your hardware'''
WORLD_SIZE = torch.cuda.device_count() # DistributedDataParallel: one process per GPU
NUM_WORKERS = 18

root_in = '../input/small-jpegs-fgvc' # Folder with input (image, label)
root_out = './' # Folder with output (csv, pth)
have_index = False # True if labels_index.csv (cultivar -> index mapping) already exists in root_out
SEED = 42

'''ArcFace parameters'''
NUM_CLASSES = 100
EMBEDDING_SIZE = 1024
S, M = 30.0, 0.5 # S: scale of the cosine logits in ArcFace. M: additive angular margin (radians)
EASY_MERGING, LS_EPS = False, 0.0 # easy_margin flag and label-smoothing epsilon of ArcMarginProduct

# Utils


### train and test utils

In [None]:
class Accumulator():
    '''A counter util that accumulates float values for `nums` metrics'''
    def __init__(self, nums):
        self.metric = [0.0] * nums
        
    def __getitem__(self, index):
        return self.metric[index]
    
    def add(self, *args):
        for i, item in enumerate(args):
            self.metric[i] += float(item)

In [None]:
def accuracy(y_hat, y):
    '''Count the correctly classified samples in a batch'''
    # argmax of the logits is the predicted class (exp() is monotonic, so it is not needed)
    y_hat = y_hat.argmax(dim=1).reshape(-1)
    y = y.reshape(-1)
    return accuracy_score(y.cpu().numpy(), y_hat.cpu().numpy(), normalize=False)

In [None]:
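def evaluate_accuracy(net, data_iter, device=None):
    '''Evaluate accuracy on the validation set (pooled over all DDP ranks)'''
    if isinstance(net, nn.Module):
        net.eval()
        if not device:
            device = next(iter(net.parameters())).device

    metric = Accumulator(2)
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list):
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)  # .to() is not in-place, so the result must be assigned
            with torch.cuda.amp.autocast(enabled=True):
                # passing y applies the ArcFace margin to the true class, so this estimate is pessimistic
                metric.add(accuracy(net(X, y), y), y.numel())
    if dist.is_available() and dist.is_initialized():
        # each rank only saw its own shard of data_iter: sum (correct, total) over all ranks
        counts = torch.tensor([metric[0], metric[1]], device=device)
        dist.all_reduce(counts, op=dist.ReduceOp.SUM)
        return (counts[0] / counts[1]).item()
    return metric[0] / metric[1]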

In [None]:
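def predict_test(net, test_iter, device=None):
    '''Inference: return the top-5 class indices for every test image'''
    if isinstance(net, DDP):
        net = net.module  # unwrap DDP to reach extract() and fc
    net.eval()
    if not device:
        device = next(iter(net.parameters())).device
    net.to(device)
    y = []
    with torch.no_grad():
        for X in test_iter:
            if isinstance(X, list):
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            with torch.cuda.amp.autocast(enabled=True):
                # forward() needs labels, so use the label-free embedding instead
                embedding = net.extract(X)
            # ArcFace scores without margin: cosine between embeddings and class centres
            cosine = F.linear(F.normalize(embedding.float()), F.normalize(net.fc.weight.float()))
            y += cosine.topk(5, dim=1).indices.cpu()
    return list(Y.numpy() for Y in y)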

In [None]:
def seed_everything(seed):
    pl.seed_everything(seed)  # public API; pl.utilities.seed.seed_everything is deprecated
    return seed

### Model utils


In [None]:
def freeze_pretrained_layers(model):
    '''Freeze all layers except the last layer(fc or classifier)'''
    for param in model.parameters():
        param.requires_grad = False
    #nn.init.xavier_normal_(model.fc.weight)
    #nn.init.zeros_(model.fc.bias)
    model.classifier.weight.requires_grad = True
    model.classifier.bias.requires_grad = True

In [None]:
def debarcle_layers(model, num_debarcle=0, db_all=False):
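    '''Unfreeze ("debarcle") parameters from the last [-1] layer back to the [-num_debarcle] layer,
    approximately (some Conv2d layers have only a weight parameter).
    If db_all == True, unfreeze all layers.'''
    param_name = [name for name, _ in model.named_parameters()]  # all parameter names, in order
    num_debarcle *= 2  # roughly one weight + one bias per layer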
    param_debarcle = param_name[-num_debarcle:]
    if param_debarcle[0].split('.')[-1] == 'bias':
        param_debarcle = param_name[-(num_debarcle + 1):]
    if db_all:
        for name, param in model.named_parameters():
            param.requires_grad = True
    else:
        for name, param in model.named_parameters():
            param.requires_grad = True if name in param_debarcle else False

### DDP utils

In [None]:
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)

In [None]:
def cleanup():
    dist.destroy_process_group()

In [None]:
def prepare(dataset, rank, world_size, batch_size=BATCH_SIZE, pin_memory=False, num_workers=0):
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
    dataloader = DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers, drop_last=False, sampler=sampler)
    return dataloader
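`prepare()` wraps each dataset in a `DistributedSampler`, so every GPU process iterates over its own slice of the data. In the training loop, `train_loader.sampler.set_epoch(epoch)` reshuffles that split every epoch. Below is a pure-Python sketch of the split with `drop_last=False`, as in `prepare()`. `shard_indices` is a hypothetical helper. It uses `random.Random` instead of PyTorch's generator, so the actual permutation differs.

```python
import math
import random


def shard_indices(dataset_len, num_replicas, rank, epoch=0, seed=0, shuffle=True):
    """Dataset indices seen by `rank` in `epoch`, mimicking DistributedSampler(drop_last=False)."""
    indices = list(range(dataset_len))
    if shuffle:
        # every rank builds the same permutation because they share seed + epoch
        random.Random(seed + epoch).shuffle(indices)
    num_samples = math.ceil(dataset_len / num_replicas)   # per-rank length
    total_size = num_samples * num_replicas
    # pad by repeating indices from the start so all ranks get equally many samples
    padding = total_size - dataset_len
    indices += (indices * math.ceil(padding / dataset_len))[:padding]
    # strided split: rank r takes positions r, r + num_replicas, ...
    return indices[rank:total_size:num_replicas]


shards = [shard_indices(10, 4, rank, epoch=0) for rank in range(4)]
```

With 10 samples on 4 GPUs, each rank gets 3 indices, so 2 samples are seen twice. This padding is also why per-rank validation numbers can count a few images twice.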

### Build & Check label index

In [None]:
def data_pre_access(file, output):
    '''Map each cultivar name to an integer index, save the table to `output` and return {index: cultivar}'''
    labels = pd.read_csv(file, index_col='image')
    labels_map = dict()
    labels['label_index'] = torch.zeros((labels.shape[0])).type(torch.int32).numpy()
    for i, label in enumerate(labels.cultivar.unique()):
        labels_map[i] = label
        labels.loc[labels.cultivar == label, 'label_index'] = i
    labels.to_csv(output)
    
    return labels_map

In [None]:
if have_index:
    labels_map = {}
    train_df = pd.read_csv(os.path.join(root_out,'labels_index.csv'))
    def label_f(m):
        labels_map[int(m.label_index)] = m.cultivar
    train_df.apply(label_f,axis=1)
else:
    labels_map = data_pre_access(os.path.join(root_in,'train_cultivar_mapping.csv'), output=os.path.join(root_out,'labels_index.csv'))
    train_df = pd.read_csv(os.path.join(root_out,'labels_index.csv'))

In [None]:
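check_sum = 0
for key, val in tqdm(labels_map.items()):
    # every index should map to exactly one cultivar: the one recorded in labels_map
    cultivars = train_df[train_df.label_index == key].cultivar.unique()
    check_sum += int(len(cultivars) == 1 and cultivars[0] == val)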

In [None]:
check_sum, check_sum==len(labels_map)

# Dataset

In [None]:
class Sorghum_Train_Dataset(Dataset):
    '''Train Dataset'''
    def __init__(self, img_path_csv='', df=None, transform=None):
        if df is not None:
            self.df = df
        else:
            self.df = pd.read_csv(img_path_csv)
        self.transform = transform
        
    def __len__(self):
        return self.df.shape[0]
    
    def __getitem__(self, index):
        row = self.df.iloc[index]
        img = Img.open(os.path.join(root_in, 'train', row['image']))
        label_index = row['label_index']  # look columns up by name rather than by position
        if self.transform is not None:
            img = self.transform(img)
        return img, label_index

In [None]:
class Sorghum_Test_Dataset(Sorghum_Train_Dataset):
    '''Test Dataset'''
    def __getitem__(self, index):
        img = Img.open(os.path.join(root_in, 'test', self.df.iloc[index, 0]))  # PIL.Image is imported as Img
        if self.transform:
            img = self.transform(img)
        return img

# Look Inside Sorghum breed

In [None]:
labels = pd.read_csv(os.path.join(root_out, 'labels_index.csv'))

In [None]:
sns.catplot(y="cultivar", kind="count", data=labels, height=20)

In [None]:
def first(ls):
    for i, flag in enumerate(ls):
        if flag == True:
            return i
        
samples = []
for cultivar in labels.cultivar.unique():
    img, label = labels.iloc[first(labels.cultivar == cultivar), [0, 1]]
    samples += [(img, label)]

f, axarr = plt.subplots(3,3,figsize=(50,50))
for i in range(3):
    for j in range(3):
        axarr[i, j].imshow(Img.open(os.path.join(root_in, 'train', samples[3*i + j][0])))
        axarr[i, j].set_title(samples[3*i + j][1])

# CLAHE

In [None]:
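def CLAHE_Convert(origin_input):
    '''Apply CLAHE to the V (brightness) channel of a PIL RGB image'''
    clahe = cv.createCLAHE(clipLimit=40, tileGridSize=(10,10))
    t = np.asarray(origin_input)
    t = cv.cvtColor(t, cv.COLOR_RGB2HSV)  # PIL images are RGB, not BGR
    t[:,:,-1] = clahe.apply(t[:,:,-1])    # equalize V only, keeping hue and saturation
    t = cv.cvtColor(t, cv.COLOR_HSV2RGB)
    t = Img.fromarray(t)
    return t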
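`CLAHE_Convert` equalizes only the V channel of HSV. The small stdlib sketch below illustrates why that brightens leaves without shifting their color: changing V alone leaves hue and saturation untouched, up to 8-bit rounding in OpenCV. `scale_brightness` is a hypothetical helper, and a uniform brightness factor stands in for CLAHE's local, per-region mapping.

```python
import colorsys


def scale_brightness(rgb, factor):
    """Scale only V in HSV space and convert back; rgb holds floats in [0, 1].

    A uniform factor stands in for CLAHE's local, per-region mapping.
    """
    h, s, v = colorsys.rgb_to_hsv(*rgb)
    return colorsys.hsv_to_rgb(h, s, min(1.0, v * factor))


dark_leaf = (0.10, 0.30, 0.05)
brighter_leaf = scale_brightness(dark_leaf, 2.0)
```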

In [None]:
p0 = Img.open(os.path.join(root_in, 'train', samples[10*9 + 3][0]))
g0 = transforms.Grayscale(num_output_channels=1)(p0)
t0 = CLAHE_Convert(p0)
n0 = transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])(transforms.ToTensor()(t0))
nn0 = transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])(transforms.ToTensor()(p0))

p1 = Img.open(os.path.join(root_in, 'train', samples[10*7 + 9][0]))
g1 = transforms.Grayscale(num_output_channels=1)(p1)
t1 = CLAHE_Convert(p1)
n1 = transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])(transforms.ToTensor()(t1))
nn1 = transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])(transforms.ToTensor()(p1))

f, axarr = plt.subplots(2,5,figsize=(25,10))
axarr[0,0].imshow(g0, 'gray')
axarr[0,0].set_title("GRAY",fontsize=26)
axarr[0,1].imshow(p0)
axarr[0,1].set_title("ORIGINAL",fontsize=26)
axarr[0,2].imshow(t0)
axarr[0,2].set_title("CLAHE",fontsize=26)
axarr[0,3].imshow(n0.permute(1, 2, 0))
axarr[0,3].set_title("Normalize(CLAHE)",fontsize=26)
axarr[0,4].imshow(nn0.permute(1, 2, 0))
axarr[0,4].set_title("Normalize(ORIGINAL)",fontsize=26)

axarr[1,0].imshow(g1, 'gray')
axarr[1,1].imshow(p1)
axarr[1,2].imshow(t1)
axarr[1,3].imshow(n1.permute(1, 2, 0))
axarr[1,4].imshow(nn1.permute(1, 2, 0))

# Model

In [None]:
class ArcMarginProduct(nn.Module):
    r"""Implement of large margin arc distance: :
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        s: norm of input feature
        m: margin
        cos(theta + m)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        s: float,
        m: float,
        easy_margin: bool,
        ls_eps: float,
        rank
    ):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.ls_eps = ls_eps  # label smoothing
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m
        self.rank = rank

    def forward(self, input: torch.Tensor, label: torch.Tensor, device = 'cuda') -> torch.Tensor:
        # --------------------------- cos(theta) & phi(theta) ---------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
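        # Cast to float32: under autocast the linear layer may return fp16, and the margin math needs full precision
        cosine = cosine.to(torch.float32)

        # clamp avoids NaN from sqrt of a tiny negative number when |cosine| rounds above 1
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))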
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------
        # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
        one_hot = torch.zeros(cosine.size(), device=self.rank)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features
        # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output

In [None]:
class SorghumModel(nn.Module):
    def __init__(self, model_name, embedding_size, map_location, k_fold, rank, pretrained=True):
        super(SorghumModel, self).__init__()       
        
        #model_effecient_b6 = timm.create_model(model_name, pretrained=pretrained, num_classes=NUM_CLASSES)
        #global param_name
        #param_name = [name for name,_ in model_effecient_b6.named_parameters()] # All parameters name
        #del model_effecient_b6
            
        self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=NUM_CLASSES)
        
        #freeze_pretrained_layers(self.model)
        #debarcle_layers(self.model, db_all=True) # Debarcle all layers()
        
        print('load Start!!!')
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Identity()
        self.pooling = self.model.global_pool
        self.model.global_pool = nn.Identity()
        #self.pooling = GeM()
        self.rank = rank
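        # nn.ModuleList (not a plain list) so that net.eval() also switches these Dropouts off
        self.multiple_dropout = nn.ModuleList([nn.Dropout(0.25) for i in range(8)])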
        self.embedding = nn.Linear(in_features * 2, embedding_size)
        self.fc = ArcMarginProduct(embedding_size, 
                                   NUM_CLASSES,
                                   S, 
                                   M, 
                                   EASY_MERGING, 
                                   LS_EPS,
                                  self.rank)

    def forward(self, images, labels):
        features = self.model(images)
        pooled_features_avg = self.pooling(features).flatten(1)
        pooled_features_max = nn.AdaptiveMaxPool2d((1,1))(features).flatten(1)
        pooled_features = torch.cat((pooled_features_avg, pooled_features_max), dim=1)
        pooled_features_dropout = torch.zeros((pooled_features.shape),device=self.rank)
        for i in range(8):
            pooled_features_dropout += self.multiple_dropout[i](pooled_features)
        pooled_features_dropout /= 8
        embedding = self.embedding(pooled_features_dropout)
        #pooled_features = nn.Dropout(0.5)(pooled_features)
        #embedding = self.embedding(pooled_features)
        output = self.fc(embedding, labels)
        return output
    
    def extract(self, images):
        features = self.model(images)
        pooled_features_avg = self.pooling(features).flatten(1)
        pooled_features_max = nn.AdaptiveMaxPool2d((1,1))(features).flatten(1)
        pooled_features = torch.cat((pooled_features_avg, pooled_features_max), dim=1)
        embedding = self.embedding(pooled_features)
        return embedding
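`SorghumModel.forward` builds its head from two tricks: concat pooling (average pool plus `AdaptiveMaxPool2d`, concatenated) and 8 `Dropout(0.25)` copies averaged before `self.embedding`. Below is a numpy sketch of that head. `concat_pool` and `multi_sample_dropout` are hypothetical helpers, and the embedding weight is random. Because the embedding layer is linear, averaging the dropped-out features first equals averaging the 8 embeddings. The multi-sample dropout paper instead averages the per-sample losses, which is not identical once the nonlinear loss is applied.

```python
import numpy as np


def concat_pool(features):
    """(N, C, H, W) -> (N, 2C): global average pool and global max pool, concatenated."""
    return np.concatenate([features.mean(axis=(2, 3)), features.max(axis=(2, 3))], axis=1)


def multi_sample_dropout(x, p=0.25, n_samples=8, training=True, rng=None):
    """Average n_samples independent inverted-dropout copies of x.

    At eval time dropout is the identity, so x is returned unchanged.
    """
    if not training:
        return x
    rng = np.random.default_rng() if rng is None else rng
    keep = rng.random((n_samples,) + x.shape) >= p   # one Bernoulli mask per sample
    return (x * keep / (1.0 - p)).mean(axis=0)


rng = np.random.default_rng(0)
features = rng.normal(size=(2, 4, 7, 7))              # backbone feature map
pooled = concat_pool(features)                        # shape (2, 8)
embedding_weight = rng.normal(size=(8, 3))            # hypothetical embedding layer
embedding = multi_sample_dropout(pooled, rng=rng) @ embedding_weight
```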

# Train 

In [None]:
def train_model_dist(rank,  world_size, model_name, num_epochs, loss_name, lr, lr_min, weight_decay, optim, use_amp, init, scheduler_type, warm_up):
    '''Parameters:
        lr(float): the initial learning rate
        lr_min(float): min learning rate
        optim(String): the optimizer type
        use_amp(Boolean): Use mixed precision on GPU or not
        init(Boolean): Re-initialize layer parameters with Xavier init or not
        scheduler_type(String): Learning rate scheduler
        warm_up(int): number of exponential-warmup epochs before the cosine schedule

       Detail:
        The train process saves the model's parameters every 10 epochs and whenever
        the validation accuracy reaches a new best.
        Every epoch, the scheduler updates once, the running train accuracy is printed 5 times
        and the validation accuracy once.
    '''
    seed_everything(SEED)
    
    if have_index:
        labels_map = {}
        train_df = pd.read_csv(os.path.join(root_out,'labels_index.csv'))
        def label_f(m):
            labels_map[int(m.label_index)] = m.cultivar
        train_df.apply(label_f,axis=1)
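    else:
        labels_map = data_pre_access(os.path.join(root_in,'train_cultivar_mapping.csv'), output=os.path.join(root_out,'labels_index.csv'))
        # train_df is local to this function, so it must be loaded in both branches
        train_df = pd.read_csv(os.path.join(root_out,'labels_index.csv'))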
    
    train_transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ColorJitter(brightness=0.2, contrast=0.05, saturation=0.1),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomApply(transforms=
                      [transforms.RandomResizedCrop(size=IMAGE_SIZE, scale=(0.3,0.4), 
                                                    ratio=(1/3,3),interpolation=
                                                    transforms.InterpolationMode.BICUBIC)],p=0.2),
        transforms.ToTensor(),
        # Normalize to fit pretrained model
        transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])

    val_test_transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        # Normalize to fit pretrained model
        transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])

    sfolder = StratifiedKFold(n_splits=FOLD,random_state=SEED,shuffle=True)
    train_folds = []
    val_folds = []
    for train_idx, val_idx in sfolder.split(train_df.image, train_df.label_index):
        train_folds.append(train_idx)
        val_folds.append(val_idx)
        print(len(train_folds), len(val_folds))
        

    
    def init_xavier(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_normal_(m.weight)
            
    

    if rank == 0:
        os.makedirs('/root/tf-logs/' + model_name, exist_ok=True)
        writer = SummaryWriter(log_dir='/root/tf-logs/' + model_name)
    
    
    print('training on', rank) 
    setup(rank, world_size)
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    
    for k_fold in range(FOLD):
        if rank == 0:  
            Value_train_l = list()
            Value_train_acc = list()
            Value_test_acc = list()
            Time = list()
            print('\n ********** Fold %d **********\n'%k_fold)
            sub_fold = model_name + '_F_' + str(k_fold)
            os.makedirs(os.path.join(root_out, sub_fold), exist_ok=True)

        train_dataset = Sorghum_Train_Dataset(df=train_df.iloc[train_folds[k_fold]],
                                            transform=train_transform)

        val_dataset = Sorghum_Train_Dataset(df=train_df.iloc[val_folds[k_fold]],
                                          transform=val_test_transform)
        
        net = SorghumModel(model_name, EMBEDDING_SIZE, map_location, k_fold, rank)

        net = net.to(rank)
        net = DDP(net, device_ids=[rank], output_device=rank)

        
        #net.load_state_dict(torch.load(os.path.join(root_out, 
        #                                            model_name+ '_F_' + str(k_fold),
        #                                            'Sorghum3.params'),
        #                               map_location=map_location))
        print('load Finish!!!')

        train_loader = prepare(train_dataset, rank, world_size, num_workers=NUM_WORKERS)
        val_loader = prepare(val_dataset, rank, world_size, num_workers=NUM_WORKERS)
     
        if optim == 'sgd':
            optimizer = torch.optim.SGD((param for param in net.parameters() if param.requires_grad), lr=lr, weight_decay=weight_decay)
        elif optim == 'adam':
            optimizer = torch.optim.Adam((param for param in net.parameters() if param.requires_grad), lr=lr, weight_decay=weight_decay)
        elif optim =='adamW':
            optimizer = torch.optim.AdamW((param for param in net.parameters() if param.requires_grad), lr=lr, weight_decay=weight_decay)
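        elif optim == 'ranger':
            # Ranger is not defined in this notebook; import it first (e.g. from the `torch_optimizer` package)
            optimizer = Ranger((param for param in net.parameters() if param.requires_grad), lr=lr, weight_decay=weight_decay)
        else:
            raise ValueError(f'Unknown optimizer: {optim}')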

        scaler = torch.cuda.amp.GradScaler(enabled=use_amp) # mixed_precison

        if scheduler_type == 'Cosine':
            scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=lr_min)
                
        def warm_up_scheduler(epoch):
            # exponential warmup: lr * (1/2)**(warm_up - epoch), reaching the full lr at epoch == warm_up
            return (1 / 2) ** (warm_up-epoch)
        
        if loss_name == 'CrossEntropy':
            loss = nn.CrossEntropyLoss(label_smoothing=0.1)
            
        if init:
            net.apply(init_xavier)
            
        num_batches = len(train_loader)
        best_accuracy = 0
        
        scheduler = LambdaLR(optimizer, lr_lambda=warm_up_scheduler)
        
        for epoch in range(num_epochs):
            if epoch == warm_up:
                if scheduler_type == 'Cosine':
                    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs-warm_up, eta_min=lr_min)
                
            net.train()
            train_loader.sampler.set_epoch(epoch)
            val_loader.sampler.set_epoch(epoch)

            metric = Accumulator(3)

            for i, (X, y) in enumerate(tqdm(train_loader)):
                X = X.to(rank)
                y = y.to(rank)
                with torch.cuda.amp.autocast(enabled=use_amp):
                    y_hat = net(X, y)
                    l = loss(y_hat, y)

                scaler.scale(l).backward()
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)

                with torch.no_grad():
                    metric.add(l * X.shape[0], accuracy(y_hat, y), X.shape[0])

                if rank == 0:    
                    train_l = metric[0] / metric[2]
                    train_acc = metric[1] / metric[2]
                    if (i + 1) % max(1, num_batches // 5) == 0 or i == num_batches - 1:
                        print(epoch + (i + 1) / num_batches,
                                     'train_l train_acc\t',(train_l, train_acc,None))
                        writer.add_scalars('Loss/Accuracy/train/Fold-' + str(k_fold), 
                                           {'train_accuracy':np.array(train_acc), 'train_loss':np.array(train_l)}, 
                                           5 * np.array(epoch + (i + 1) / num_batches))
                        Value_train_l.append(train_l)
                        Value_train_acc.append(train_acc)
                        Value_test_acc.append(None)
                        Time.append(epoch + (i + 1) / num_batches)

            scheduler.step()
            
            test_acc = evaluate_accuracy(net, val_loader, device=rank)
            
            if rank == 0: 
                print('lr = ', optimizer.param_groups[0]['lr'])
                print(epoch + 1,'test_acc\t', (None, None, test_acc))
                writer.add_scalars('Loss/Accuracy/test/Fold-' + str(k_fold), 
                                   {'val_accuracy':np.array(test_acc)}, 
                                   5 * np.array(epoch + 1))
                Value_train_l.append(None)
                Value_train_acc.append(None)
                Value_test_acc.append(test_acc)
                Time.append(epoch + 1)

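                # track the best accuracy separately, so a periodic save cannot lower it
                is_best = test_acc >= best_accuracy
                best_accuracy = max(best_accuracy, test_acc)
                if epoch % 10 == 0 or is_best:
                    torch.save(net.state_dict(),os.path.join(root_out, sub_fold, 'Sorghum' + str(epoch + 1) + '.params'))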
                record_data = pd.DataFrame(zip(Value_train_l, Value_train_acc, Value_test_acc, Time))    
                record_data.to_csv(os.path.join(root_out, sub_fold, 'Record_Sorghum.csv')) 

        if rank == 0:
            torch.save(net.state_dict(),os.path.join(root_out, sub_fold, 'Sorghum.params'))

            print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
                  f'test acc {test_acc:.3f}')
            print(f'on {str(rank)}')
        torch.cuda.empty_cache()
    cleanup()
    if rank == 0:
        writer.close()  # the SummaryWriter only exists on rank 0
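The training loop above runs each batch through `scaler.scale(l).backward()`, `scaler.step(optimizer)` and `scaler.update()`. Mixed precision stays numerically safe because of this dynamic loss scaling. The loss is multiplied by a large factor so that small fp16 gradients do not underflow. A step whose gradients overflow is skipped and the factor is halved. After a long run of clean steps, the factor is doubled. Below is a pure-Python sketch of that policy. `ToyLossScaler` is hypothetical, not PyTorch's implementation, and it merges `step` and `update` into one call.

```python
class ToyLossScaler:
    """Simplified dynamic loss scaling, modelled on torch.cuda.amp.GradScaler.

    Defaults follow GradScaler's defaults (init_scale=2**16, growth_factor=2.0,
    backoff_factor=0.5, growth_interval=2000). The real class works on GPU
    tensors and checks every gradient for inf/NaN itself.
    """

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def scale_loss(self, loss):
        # like scaler.scale(l): enlarge the loss so small fp16 gradients do not underflow to 0
        return loss * self.scale

    def unscale(self, grads):
        # divide the gradients back before the optimizer uses them
        return [g / self.scale for g in grads]

    def step_and_update(self, grads_are_finite):
        """Return True if the optimizer step should run, then adjust the scale."""
        if not grads_are_finite:
            self.scale *= self.backoff_factor   # overflow: skip the step, shrink the scale
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            self.scale *= self.growth_factor    # many clean steps in a row: try a larger scale
            self._good_steps = 0
        return True


scaler = ToyLossScaler(growth_interval=3)
```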

# Train with DDP

In [None]:
if __name__ == '__main__':
    print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')    
    mp.spawn(
        train_model_dist,
        args=(WORLD_SIZE, MODEL_NAME, EPOCH, LOSS, LR, LR_MIN, WEIGHT_DECAY, OPTIM, USE_AMP, INIT, SCHEDULER, WARM_UP),
        nprocs=WORLD_SIZE
    )

### Convert notebook to .py file 

In [None]:
!jupyter nbconvert --to script Sorghum_DDP.ipynb

### Run your .py file in command line

In [None]:
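'''Remember to run it from the command line, not inside the Jupyter notebook.
   In the .py file, also delete the conversion cell above and the `!pip install` cells
   (nbconvert turns `!` commands into get_ipython() calls, which fail under plain python).'''

> python Sorghum_DDP.py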