In [7]:
import os
import torch
from inference_ram_plus import inference, ram_plus, get_transform
from torch.multiprocessing import Pool
from PIL import Image

def process_image(image_path):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform = get_transform(image_size=384)

    # Load model
    model = ram_plus(pretrained='pretrained/ram_plus_swin_large_14m.pth',
                     image_size=384,
                     vit='swin_l')
    model.eval()
    model = model.to(device)

    # Process image
    image = transform(Image.open(image_path)).unsqueeze(0).to(device)
    result = inference(image, model)

    return result[0]

def get_tags_image_list_parallel(image_list):
    # Use multiprocessing Pool to parallelize
    with Pool(processes=5) as pool:
        results = pool.map(process_image, image_list)

    return results

image_list = []

for image in os.listdir('images/stock'):
    image_list.append(os.path.join('images/stock', image))

image_list

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run_inference(rank, world_size):
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    
    # load a model 
    model = YourModel()
    model.load_state_dict(PATH)
    model.eval()
    model.to(rank)

    # create a dataloader
    dataset = ...
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=4)

    # iterate over the loaded partition and run the model
    for idx, data in enumerate(loader):
            ...

def main():
    world_size = 4
    mp.spawn(run_inference,
        args=(world_size,),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    main()

['images/stock/image3.jpg',
 'images/stock/image7.jpeg',
 'images/stock/image5.jpeg',
 'images/stock/image2.jpg',
 'images/stock/image4.jpeg',
 'images/stock/image11.png',
 'images/stock/image9.jpeg',
 'images/stock/image1.jpg',
 'images/stock/image8.jpeg',
 'images/stock/image10.jpeg',
 'images/stock/image6.jpeg']

In [8]:
%timeit get_tags_image_list_parallel(image_list)

KeyboardInterrupt: 

In [3]:
# Write code to time the two functions below.
# Which one is faster? By how much?

# %timeit get_tags_image_list(image_list)

--------------
pretrained/ram_plus_swin_large_14m.pth
--------------
load checkpoint from pretrained/ram_plus_swin_large_14m.pth
vit: swin_l
--------------
pretrained/ram_plus_swin_large_14m.pth
--------------
load checkpoint from pretrained/ram_plus_swin_large_14m.pth
vit: swin_l
--------------
pretrained/ram_plus_swin_large_14m.pth
--------------
load checkpoint from pretrained/ram_plus_swin_large_14m.pth
vit: swin_l
--------------
pretrained/ram_plus_swin_large_14m.pth
--------------
load checkpoint from pretrained/ram_plus_swin_large_14m.pth
vit: swin_l
--------------
pretrained/ram_plus_swin_large_14m.pth
--------------
load checkpoint from pretrained/ram_plus_swin_large_14m.pth
vit: swin_l
--------------
pretrained/ram_plus_swin_large_14m.pth
--------------
load checkpoint from pretrained/ram_plus_swin_large_14m.pth
vit: swin_l
--------------
pretrained/ram_plus_swin_large_14m.pth
--------------
load checkpoint from pretrained/ram_plus_swin_large_14m.pth
vit: swin_l
-------------