## Setting up parallel training on N GPU/CPU

In [1]:
#!g1.2 #noqa
import random
import sys

import torch
import torch.multiprocessing as mp
from torchvision import transforms

In [2]:
#!g1.2 #noqa
# Make the helper modules imported below (processingDataSet, distLearningFunc)
# importable; presumably they live in this directory.
# Guarded so re-running the cell does not append duplicate sys.path entries.
if '/home/jupyter/work/resources/' not in sys.path:
    sys.path.append('/home/jupyter/work/resources/')

In [3]:
#!g1.2 #noqa
from processingDataSet import MaskDataset, PreprocessingData, get_not_RGB_pic

In [4]:
#!g1.2 #noqa
# Seed passed to prData.get_data and reused as the training workers' seed.
random_state = 10

# Dataset root used for training.
# Alternative location (disabled): '/home/jupyter/work/resources/models/datasets/segNet'
data_path = '/home/jupyter/mnt/datasets/Segmentation/Training'

# Pretrained encoder weights (disabled):
# model_weights_dir = '/home/jupyter/work/resources/figureExtraction/weights/'
# weights_path = model_weights_dir + 'pretrained_encoder_weights_DEFAULT.pt'

In [5]:
#!g1.2 #noqa
# Build the train / validation sample lists from the dataset at data_path.
# NOTE(review): the meanings of 0.9 and 0.009 are defined in processingDataSet
# (not visible here). Presumably 0.9 is the train fraction and 0.009 the
# fraction of the dataset that is sampled; confirm before changing them.
prData = PreprocessingData(0.9)
train_data, val_data = prData.get_data(data_path, random_state, 0.009)

In [6]:
#!g1.2 #noqa
# Report how many samples ended up in each split.
print(
    'Train data size: ', len(train_data),
    'Validation data size: ', len(val_data),
)

In [7]:
#!g1.2 #noqa
# Training-time augmentation.
# RandomRotation already draws a new angle for every sample, so its range is
# fixed here: here each sample is rotated by an angle in [0, 180] degrees.
# Picking the range bound itself at random (with the unseeded `random` module)
# would only make runs irreproducible, and could disable rotation entirely.
# NOTE(review): these are geometric transforms. Confirm MaskDataset applies
# them jointly to image and mask; otherwise the pair gets misaligned.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomPerspective(p=0.5),
    transforms.RandomRotation(degrees=(0, 180)),
])

In [8]:
#!g1.2 #noqa
# Random augmentation is applied to the training split only. The validation
# split must see unaltered samples so that its metrics are comparable between
# epochs. (The transform argument is optional: test_set further down is built
# without one.)
train_set = MaskDataset(data_path, train_data, transform)
val_set = MaskDataset(data_path, val_data)

In [9]:
#!g1.2 #noqa
from distLearningFunc import worker

# Training hyper-parameters.
epochs = 1
seed = random_state  # reuse the notebook-wide seed
batch_size = 30
world_size = 2  # number of processes to spawn (one per GPU)

# mp.spawn starts world_size processes and calls
# worker(process_index, *spawn_args) in each of them.
spawn_args = (world_size, train_set, val_set, batch_size, seed, epochs)
mp.spawn(worker, args=spawn_args, nprocs=world_size)

In [14]:
#!g1.2 #noqa
'''
    Warning from torch.multiprossesing.spawn:

    If the main process exits abruptly (e.g. because of an incoming signal),
    Python’s multiprocessing sometimes fails to clean up its children.
    It’s a known caveat, so if you’re seeing any resource leaks after interrupting the interpreter,
    it probably means that this has just happened to you.
    https://pytorch.org/docs/stable/multiprocessing.html
'''
# To fix that problem, find the PID of this proсess(es)
# (depends on how many workers were started) and kill them.
!ps -fe | grep multiprocessing.spawn

In [13]:
#!g1.2 #noqa
!kill 4159

In [10]:
#!g1.2
# List the saved checkpoints; the path is relative to the notebook's working directory.
!ls ./models/checkpoints

In [55]:
#!g1.2 #noqa
# Load the saved weights. map_location='cpu' remaps tensors that were saved on
# a specific CUDA device: the inference cells in this notebook never move the
# model to a GPU, and the saving device may not exist or be free here.
trained_state_dict = torch.load(
    '/home/jupyter/work/resources/figureExtraction/models/checkpoints/2024_01_19_13_50_48.pt',
    map_location='cpu',
)

In [56]:
#!g1.2
# Inspect the parameter names stored in the checkpoint, e.g. to spot a
# DistributedDataParallel 'module.' prefix before loading it into a bare model.
trained_state_dict.keys()

In [12]:
#!g1.2
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

# NOTE(review): SegNet is not imported anywhere in the visible notebook.
# Import it from the module that defines it before running this cell.
model_trained = SegNet()
# A checkpoint saved from a DistributedDataParallel wrapper has keys prefixed
# with 'module.', which a bare SegNet rejects. Strip that prefix in place;
# this is a no-op when the keys carry no such prefix.
consume_prefix_in_state_dict_if_present(trained_state_dict, 'module.')
model_trained.load_state_dict(trained_state_dict)

In [14]:
#!g1.2
# Evaluation dataset over the validation split, built without the augmentation `transform`.
test_set = MaskDataset(data_path, val_data)

In [32]:
#!g1.2
from torchvision import io

In [46]:
#!g1.2
# Load one image/mask pair from disk and preprocess it by hand for inference.
img_path = '/home/jupyter/work/resources/figureExtraction/datasets/segNet/input/227.jpg'
mask_path = '/home/jupyter/work/resources/figureExtraction/datasets/segNet/Output/227.png'
# Force 3 channels: a grayscale or RGBA file would otherwise break the
# 3-channel Normalize below.
image = io.read_image(img_path, mode=io.ImageReadMode.RGB)
mask = io.read_image(mask_path)
norm = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
to_resized_tensor = transforms.Resize([224, 224], antialias=True)
# Nearest-neighbour keeps mask values discrete; bilinear/antialiased resizing
# would blend label values along object boundaries.
resize_mask = transforms.Resize(
    [224, 224],
    interpolation=transforms.InterpolationMode.NEAREST,
    antialias=False,
)
# read_image returns uint8 in [0, 255]; div(255) rescales to floats in [0, 1].
image = norm(to_resized_tensor(image).div(255))
mask = resize_mask(mask).div(255)

In [47]:
#!g1.2
# Run the trained model on the single preprocessed image.
# eval() switches layers such as BatchNorm/Dropout (if SegNet has any) to
# inference behaviour; no_grad() avoids building an autograd graph.
model_trained.eval()
with torch.no_grad():
    # Alternative input: a dataset sample, e.g. test_set[0][0].unsqueeze(0)
    y = model_trained(image.unsqueeze(0))  # unsqueeze(0) adds a batch dim of 1

In [17]:
#!g1.2
from matplotlib import pyplot as plt

In [None]:
#!g1.2
# Show the dimensions of the model output.
y.size()

In [18]:
#!g1.2
import numpy as np
def conv_to_img(tensor: torch.Tensor) -> np.ndarray:
    """Convert a tensor into a NumPy array that ``plt.imshow`` can display.

    The tensor is detached from autograd, moved to the CPU, stripped of every
    size-1 dimension and clipped to ``[0, 1]``. No channel reordering or
    de-normalization is done (compare ``conv_to_img1``), so a 2-D result is
    only obtained for single-channel inputs.

    Args:
        tensor: Tensor of any shape, on any device, with or without grad.

    Returns:
        A new NumPy array with values clipped to ``[0, 1]``; the caller's
        tensor is never modified.
    """
    # detach() first so .numpy() also works on tensors that require grad.
    # ndarray.clip returns a fresh array, so no clone() is needed to protect
    # the caller's tensor from modification.
    img = tensor.detach().cpu().numpy().squeeze()
    return img.clip(0, 1)

In [48]:
#!g1.2
# Turn the model output into a displayable NumPy array.
res = conv_to_img(tensor=y)

In [20]:
#!g1.2
def conv_to_img1(tensor: torch.Tensor,
                 mean: tuple = (0.485, 0.456, 0.406),
                 std: tuple = (0.229, 0.224, 0.225)) -> np.ndarray:
    """Convert a normalized image tensor into an (H, W, C) array for ``plt.imshow``.

    The tensor is detached, moved to the CPU and squeezed. The remaining
    (C, H, W) array is transposed to (H, W, C), the ``Normalize(mean, std)``
    step is undone (``x * std + mean``), and values are clipped to ``[0, 1]``.

    Args:
        tensor: Image tensor that squeezes to exactly three dims in (C, H, W)
            order, e.g. ``(C, H, W)`` or ``(1, C, H, W)``. ``squeeze`` removes
            every size-1 dim, so a single-channel or one-pixel-high image will
            not transpose correctly.
        mean: Per-channel mean that was subtracted during normalization. The
            default matches the ``transforms.Normalize`` call in this notebook.
        std: Per-channel standard deviation that was divided out.

    Returns:
        A new float array of shape (H, W, C) with values in ``[0, 1]``.
    """
    img = tensor.detach().cpu().numpy().squeeze()
    img = img.transpose(1, 2, 0)  # CHW -> HWC, the layout imshow expects
    img = img * np.asarray(std) + np.asarray(mean)
    return img.clip(0, 1)

In [23]:
#!g1.2
# Dimensions of the first element of the second test sample.
test_set[1][0].size()

In [28]:
#!g1.2
# Show the first test sample's image with the normalization undone.
first_image = test_set[0][0]
plt.imshow(conv_to_img1(first_image))

In [49]:
#!g1.2
# Display the model prediction converted by conv_to_img.
# NOTE(review): conv_to_img only clips to [0, 1] and applies no sigmoid/softmax.
# If SegNet returns logits, confirm this is the intended visualisation.
plt.imshow(res)

In [None]:
#!g1.2


In [50]:
#!g1.2
# Display the mask read from mask_path.
# NOTE(review): conv_to_img squeezes but does not move channels last, so this
# only works for a single-channel mask; a 3-channel PNG would need a CHW->HWC transpose.
plt.imshow(conv_to_img(mask))