In [21]:
from torch.profiler import profile, tensorboard_trace_handler, ProfilerActivity, schedule
from torch.utils.data import Dataset, DataLoader
from torch.nn.parallel import DistributedDataParallel
import torch.distributed as dist

from torch.cuda.amp import autocast, GradScaler
import torch.nn as nn
import torch
import time 
import os 

class CustomDataset(Dataset):
    def __init__(self,X,Y):
        self.X = X
        self.Y = Y
        
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self,idx):
        return self.X[idx], self.Y[idx]

class MLP(nn.Module):
    def __init__(self,c_in = 4,h_dim1 = 16, h_dim2 = 16, c_out = 1,device = 'cpu'):
        super(MLP,self).__init__()
        self.linear1 = nn.Linear(c_in,h_dim1)
        self.linear2 = nn.Linear(h_dim1,h_dim2)
        self.linear3 = nn.Linear(h_dim2,c_out)
        self.relu = nn.ReLU()
        self.device = device
    def forward(self,x):
        return(self.linear3(self.relu(self.linear2(self.relu(self.linear1(x))))))
    
def load_profile_dataloader(dataset,B,num_workers,persistent_workers,pin_memory,prefetch_factor,drop_last,dataparallel=False,prof = False):
    if prof:
        activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] if torch.cuda.is_available() else [ProfilerActivity.CPU]
        prof =  profile(activities=activities,
                        schedule=schedule(wait=1, warmup=1, active=12, repeat=1),
                        on_trace_ready=tensorboard_trace_handler('./profiler/trial_profiler'),
                        profile_memory=True,
                        record_shapes=False, 
                        with_stack=False,
                        with_flops=False
                        )
    else:
        prof = None
    
    if dataparallel:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset,
                                                                num_replicas=2,
                                                                rank=0,
                                                                shuffle=True)
    else: 
        sampler = None
        
    

    dataloader = DataLoader(dataset,batch_size=B,shuffle =False if dataparallel else True ,
                            num_workers=num_workers, 
                            persistent_workers= False if num_workers == 0 else persistent_workers,
                            pin_memory=pin_memory,
                            prefetch_factor=None if num_workers==0 else prefetch_factor, #2,3,4,5...
                            drop_last=drop_last,
                           sampler = sampler 
                           )
    
    return(prof,dataloader,sampler)


def load_model_loss_opt(c_in,h_dim1,h_dim2,c_out,device,dataparallel = False):
    loss = nn.MSELoss()
    model = MLP(c_in,h_dim1, h_dim2, c_out,device).to(device)
    if dataparallel:
        dist.init_process_group(backend='nccl',# init_method='env://',
                                world_size=2, 
                                rank=0)
        model = DistributedDataParallel(model,device_ids = [0])
    optimizer = torch.optim.SGD(model.parameters(), 1e-3)

    return(loss,model,optimizer)


def training(model,optimizer,loss,prof,epochs,dataloader,device,scaler):
    t_epochs,t_batchs, t_coms, t_forwards, t_backwards,total_time = 0,0,0,0,0,0
    for epoch in range(epochs):
        if sampler is not None:
            sampler.set_epoch(epoch)
        epoch1 = time.time()
        t_epoch = time.time()
        for x,y in dataloader:
            t_batch = time.time()
            t_com = time.time()
            x,y = x.to(device),y.to(device)
            t_coms += time.time()-t_com

            t_forward = time.time()
            pred = model(x)
            t_forwards += time.time()-t_forward
            
            if len(y.size()) != len(pred.size()): 
                pred = pred.squeeze()

            t_backward = time.time()
            l = loss(pred,y)
            t_backwards += time.time() - t_backward
            if scaler is not None:
                scaler.scale(l).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.zero_grad()
                l.backward()
                optimizer.step()

            t_batchs += time.time()-t_batch
        t_epochs += time.time()-t_epoch
        if epoch == 0:
            t_epochs = 0
            epochs1 = time.time()-epoch1
        if prof:
            prof.step()
        return(t_epochs,epochs1,t_coms,t_forwards,t_backwards)

def train_model(model,optimizer,loss,prof,epochs,dataloader,sampler,device):
    total_t = time.time()
    if sampler is not None:
        scaler = GradScaler()
    else:
        scaler = None
    if prof is not None:
        with prof:
            (t_epochs,epochs1,t_coms,t_forwards,t_backwards) = training(model,optimizer,loss,prof,epochs,dataloader,device,scaler)
    else:
        (t_epochs,epochs1,t_coms,t_forwards,t_backwards) = training(model,optimizer,loss,prof,epochs,dataloader,device,scaler)
        
    total_time = time.time() - total_t

    print(f"Total time: {total_time} \nTime per epoch: {(t_epochs)/(epochs-1)} \
        \nTime first epoch: {(epochs1)} \nTime Communication: {t_coms}\
            \nTime forwards: {t_forwards} \nTime Backward: {t_backwards}\
            ")


## Test avec different worker, sur model moyen (petit), input shape proche des miens : 

In [12]:
B = 8
T = 6000
L = 8
N = 40
epochs = 300
device = 'cuda' if torch.cuda.is_available() else 'cpu'
c_in,h_dim1,h_dim2,c_out = L, 64, 64, 1

persistent_workers = False
pin_memory = False
prefetch_factor = None
drop_last = False

X,Y=  torch.randn(T,N,L),torch.randn(T,N)
#inputs = list(zip(X,Y))
inputs = CustomDataset(X,Y)

for num_workers in [0,1,2,4,6,8]:
    print('\nNum workers:',num_workers)
    (prof,dataloader,sampler) = load_profile_dataloader(inputs,B,num_workers,persistent_workers,pin_memory,prefetch_factor,drop_last)
    (loss,model,optimizer) = load_model_loss_opt(c_in,h_dim1,h_dim2,c_out,device)
    train_model(model,optimizer,loss,prof,epochs,dataloader,sampler,device)


Num workers: 0
Total time: 0.8179500102996826 
Time per epoch: 0.0         
Time first epoch: 0.8179304599761963 
Time Communication: 0.041181087493896484            
Time forwards: 0.15056371688842773 
Time Backward: 0.04733085632324219            

Num workers: 1
Total time: 1.6966233253479004 
Time per epoch: 0.0         
Time first epoch: 1.6965746879577637 
Time Communication: 0.062432289123535156            
Time forwards: 0.1364271640777588 
Time Backward: 0.045671701431274414            

Num workers: 2
Total time: 1.8864109516143799 
Time per epoch: 0.0         
Time first epoch: 1.8863301277160645 
Time Communication: 0.12665390968322754            
Time forwards: 0.1561121940612793 
Time Backward: 0.04915332794189453            

Num workers: 4
Total time: 1.9492802619934082 
Time per epoch: 0.0         
Time first epoch: 1.9491872787475586 
Time Communication: 0.11153483390808105            
Time forwards: 0.15110492706298828 
Time Backward: 0.05306577682495117            

## Essaie avec toute les données chargées initialement en mémoire : 
Ici, impossible de les charger en mémoire en amont pour du num_worker > 0. 

In [24]:
B = 8
T = 6000
L = 8
N = 40
epochs = 300

persistent_workers = False
pin_memory = False
prefetch_factor = None
drop_last = False

device = 'cuda' if torch.cuda.is_available() else 'cpu'


c_in,h_dim1,h_dim2,c_out = L, 64, 64, 1

X,Y=  torch.randn(T,N,L).to(device),torch.randn(T,N,1).to(device)
inputs = CustomDataset(X,Y)

for num_workers in [0]:
    print('\nNum workers:',num_workers)
    (prof,dataloader,sampler) = load_profile_dataloader(inputs,B,num_workers,persistent_workers,pin_memory,prefetch_factor,drop_last)
    (loss,model,optimizer) = load_model_loss_opt(c_in,h_dim1,h_dim2,c_out,device)
    train_model(model,optimizer,loss,prof,epochs,dataloader,sampler,device)


Num workers: 0
Total time: 0.634476900100708 
Time per epoch: 0.0         
Time first epoch: 0.634458065032959 
Time Communication: 0.0035734176635742188            
Time forwards: 0.11711645126342773 
Time Backward: 0.038048505783081055            


## Ajout de : persistent_worker, pin_memory, prefetch_factor = 2 

In [22]:
persistent_workers = True
pin_memory = True
drop_last = True

X,Y=  torch.randn(T,N,L),torch.randn(T,N,1)
inputs = CustomDataset(X,Y)

for num_workers in [0,1]:
    for prefetch_factor in [None,2,4,8]:
        print('\nNum workers:',num_workers,'and prefetch_factor: ',prefetch_factor)
        (prof,dataloader,sampler) = load_profile_dataloader(inputs,B,num_workers,persistent_workers,pin_memory,prefetch_factor,drop_last)
        (loss,model,optimizer) = load_model_loss_opt(c_in,h_dim1,h_dim2,c_out,device)
        train_model(model,optimizer,loss,prof,epochs,dataloader,sampler,device)


Num workers: 0 and prefetch_factor:  None
Total time: 0.7066452503204346 
Time per epoch: 0.0         
Time first epoch: 0.7066271305084229 
Time Communication: 0.03134489059448242            
Time forwards: 0.11990141868591309 
Time Backward: 0.03914356231689453            

Num workers: 0 and prefetch_factor:  2
Total time: 0.6967222690582275 
Time per epoch: 0.0         
Time first epoch: 0.6967051029205322 
Time Communication: 0.031247854232788086            
Time forwards: 0.1203000545501709 
Time Backward: 0.03901362419128418            

Num workers: 0 and prefetch_factor:  4
Total time: 0.6892411708831787 
Time per epoch: 0.0         
Time first epoch: 0.6892240047454834 
Time Communication: 0.031081199645996094            
Time forwards: 0.11827659606933594 
Time Backward: 0.0383143424987793            

Num workers: 0 and prefetch_factor:  8
Total time: 0.6936221122741699 
Time per epoch: 0.0         
Time first epoch: 0.6936051845550537 
Time Communication: 0.03083586692810

## Choice of best config: 

In [71]:
num_workers= 1 #4
persistent_workers = True
pin_memory = True
prefetch_factor = 4
drop_last = False

## Trial with Dataparallel : 

In [None]:
import os
import socket

def find_free_port():
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.bind(('', 0))
    addr, port = s.getsockname()
    s.close()
    return port

free_port = find_free_port()
os.environ['MASTER_ADDR'] = '137.121.170.69'
os.environ['MASTER_PORT'] = 8888 #Ne fonctionne pas #str(free_port) 
print(f"Using port {free_port} for MASTER_PORT")

dataparallel = True

(prof,dataloader,sampler) = load_profile_dataloader(inputs,B,num_workers,persistent_workers,pin_memory,prefetch_factor,drop_last,dataparallel)
(loss,model,optimizer) = load_model_loss_opt(c_in,h_dim1,h_dim2,c_out,device,dataparallel)
train_model(model,optimizer,loss,prof,epochs,dataloader,sampler,device)