In [62]:
import os
import sys

import imageio
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from IPython.display import HTML

In [67]:
%%writefile shift.py

import numpy as np
import tracemalloc

from PIL import Image
from mpi4py import MPI

tracemalloc.start()

comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
start = MPI.Wtime()
H = 387
W = 580

N_per_pank = H // size

data = None

if rank == 0:
    data = np.array(Image.open('pic.jpg'))
    N_per_pank =  H - (N_per_pank * (size - 1))


received = np.empty((N_per_pank, W, 3), dtype=np.uint8)
sendcounts = np.array(comm.gather(received.size, 0))

def roll(arr):
    """Return a copy of ``arr`` shifted one position along axis 1, with wrap-around.

    Column ``j`` of the result is column ``j - 1`` of the input, and column 0
    receives the last column. Works for any array with at least two dimensions,
    including bands with zero rows.

    Parameters
    ----------
    arr : numpy.ndarray
        Array whose axis 1 (image columns here) is rotated.

    Returns
    -------
    numpy.ndarray
        New array with the same shape and dtype as ``arr``.
    """
    # Vectorised equivalent of assigning res[:, (i + 1) % w] = arr[:, i] for every column.
    return np.roll(arr, 1, axis=1)
    
# Rows are rolled independently, so no rank ever needs another rank's pixels.
# Scatter the bands once, then roll locally and gather only to save each frame.
# Re-scattering every iteration would just resend the band each rank already holds.
comm.Scatterv((data, sendcounts), received, root=0)
for i in range(W):
    received = roll(received)
    comm.Gatherv(received, (data, sendcounts), root=0)
    if rank == 0:
        # Frame i is the image shifted by i + 1 columns. Forward slashes are portable,
        # and zero padding keeps lexicographic order equal to frame order.
        # Note: this JPEG write is serial on rank 0 and is part of the timed region.
        Image.fromarray(data).save(f'images/{i:05d}.jpg')


_, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()

peaks = comm.gather(peak, 0)
if rank == 0:
    end = MPI.Wtime()
    print(f'peak traced memory (max over ranks): {max(peaks) / 2**20:.1f} MiB')
    # The elapsed time stays the last stdout line so callers can parse it.
    print(end - start)

Overwriting shift.py


In [None]:
!mpiexec -n 1 python shift.py

In [63]:
# Assemble the frames written by shift.py into an animated GIF.
# Zero-padded file names make sorted order equal to frame order.
frame_paths = sorted(
    os.path.join('images', name)
    for name in os.listdir('images')
    if name.endswith('.jpg') and os.path.isfile(os.path.join('images', name))
)
assert frame_paths, "no frames found in images/ - run shift.py first"

frames = []
for path in frame_paths:
    with Image.open(path) as frame:
        # convert() returns a detached copy, so each file is closed immediately.
        frames.append(frame.convert('RGB'))

# GIF frame delays are stored in 1/100 s units, and browsers replace delays of
# 10 ms or less with 100 ms. fps=200 (5 ms) therefore could not be honoured;
# 20 ms (50 fps) is the fastest rate that plays reliably. Pillow's duration is in ms.
frames[0].save('shift.gif', save_all=True, append_images=frames[1:], duration=20, loop=0)
HTML('<img src="shift.gif">')

In [None]:
exe_time = []
N_processes = np.arange(1,11)
for i in N_processes:
    add = %timeit !mpiexec -n {i} python shift.py
    exe_time.append(float(add[0]))
    print(f"Process {i} with time {add[0]}")

In [None]:
# Plot speedup S(p) = T(1) / T(p) against the number of MPI processes p,
# relative to the single-process run.
speedup = exe_time[0] / np.array(exe_time)
fig, ax = plt.subplots()
ax.plot(N_processes, speedup)
ax.set_title('Speedup vs number of processes')
ax.set_xlabel('Number of processes')
ax.set_ylabel('Speedup')
ax.grid(True)