# In [43]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.pyplot import imread, imshow, imsave
from scipy import ndimage

# In [50]:
class Image:
    """Image held as a NumPy array with chainable editing helpers.

    Every editing method replaces ``self.image`` with the processed result
    and returns ``self`` so calls can be chained, e.g.
    ``Image(path).resize(50).grayscale().smooth(5)``.

    Attributes:
        image: Pixel data as an array (``None`` until loaded or supplied).
        file_path: Path the image is read from by ``load()``.
        color_code: Channel-order flag: 0 = BGR, 1 = RGB, 2 = grayscale.

    NOTE(review): ``load()`` reads through matplotlib's ``imread``, which is
    documented to return RGB data, while the default ``color_code`` is 0
    (BGR) -- confirm callers pass ``color_code=1`` where the order matters.
    """

    image = None
    file_path = ""
    # BGR for 0, RGB for 1 and grayscale for 2
    color_code = 0

    def __init__(self, image_path, image=None, color_code=0):
        """Create an image from an in-memory array or from a file.

        Args:
            image_path: Path of the image file; kept in ``file_path``.
            image: Optional pixel array. When given it is used as-is and the
                file is not read; when ``None`` the file is loaded.
            color_code: Channel-order flag (0 = BGR, 1 = RGB, 2 = grayscale).
        """
        self.file_path = image_path
        self.color_code = color_code
        if image is None:
            self.load()
        else:
            self.image = image

    def load(self):
        """Read ``file_path`` into ``self.image`` with matplotlib's imread."""
        self.image = imread(self.file_path)

    def display(self, title="Image"):
        """Show the image in a new 10x10 matplotlib figure with axes hidden."""
        plt.figure(figsize=[10, 10])
        if self.image.ndim < 3:
            # Single-channel data needs an explicit gray colormap.
            plt.imshow(self.image, cmap="Greys_r")
        else:
            plt.imshow(self.image)
        plt.title(title)
        plt.axis("off")

    def resize(self, scale_percent):
        """Scale both dimensions to ``scale_percent`` percent of their size."""
        width = int(self.image.shape[1] * scale_percent / 100)
        height = int(self.image.shape[0] * scale_percent / 100)
        # cv2.resize takes the target size as (width, height).
        self.image = cv2.resize(
            self.image, (width, height), interpolation=cv2.INTER_AREA
        )
        return self

    def grayscale(self):
        """Convert a color image to grayscale; a 2-D image is left untouched.

        RGB weights are used when ``color_code`` is 1; any other value keeps
        the BGR conversion.
        """
        if self.image.ndim < 3:
            # Already single-channel: nothing to convert.
            return self
        if self.color_code == 1:
            conversion = cv2.COLOR_RGB2GRAY
        else:
            conversion = cv2.COLOR_BGR2GRAY
        self.image = cv2.cvtColor(self.image, conversion)
        return self

    def smooth(self, kernel_size):
        """Apply a Gaussian blur with a square ``kernel_size`` kernel.

        0 is passed as the sigma argument of ``cv2.GaussianBlur``.
        """
        self.image = cv2.GaussianBlur(
            self.image, (kernel_size, kernel_size), 0
        )
        return self

    def dilate(self, kernel_size, iteration):
        """Dilate with a square all-ones kernel, ``iteration`` times."""
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        self.image = cv2.dilate(self.image, kernel, iterations=iteration)
        return self

    def erode(self, kernel_size, iteration):
        """Erode with a square all-ones kernel, ``iteration`` times."""
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        self.image = cv2.erode(self.image, kernel, iterations=iteration)
        return self

    def plot_histogram(self):
        """Plot 256-bin histograms over the value range [0, 256).

        Color images get one curve for each of channels 0, 1 and 2, drawn in
        red, green and blue respectively; a 2-D image gets one gray curve.

        NOTE(review): the curve colors assume channel 0 is red (RGB order)
        regardless of ``color_code`` -- confirm against how images are loaded.
        """
        plt.figure()
        plt.xlim([0, 256])

        if self.image.ndim < 3:
            curves = [(self.image, "gray")]
        else:
            curves = [
                (self.image[:, :, channel_id], color)
                for channel_id, color in zip((0, 1, 2), ("red", "green", "blue"))
            ]

        for values, color in curves:
            histogram, bin_edges = np.histogram(
                values, bins=256, range=(0, 256)
            )
            plt.plot(bin_edges[0:-1], histogram, color=color)

        plt.title("Color Histogram")
        plt.xlabel("Color value")
        plt.ylabel("Pixel count")

    @staticmethod
    def _add_clipped(channel, value):
        """Return ``channel + value``; integer data saturates at dtype limits."""
        if np.issubdtype(channel.dtype, np.integer):
            limits = np.iinfo(channel.dtype)
            # Widen first so the sum cannot wrap around before clipping.
            shifted = channel.astype(np.int64) + value
            return np.clip(shifted, limits.min, limits.max).astype(channel.dtype)
        return channel + value

    def adjust_brightness(self, value):
        """Add ``value`` to the brightness of every pixel.

        Color images are converted to HSV, shifted on the V channel and
        converted back; 2-D images are shifted directly. For integer pixel
        types the result is clipped to the dtype range instead of wrapping
        around (e.g. 250 + 10 gives 255 for uint8, not 4).
        """
        if self.image.ndim < 3:
            self.image = self._add_clipped(self.image, value)
            return self
        hsv = cv2.cvtColor(self.image, cv2.COLOR_BGR2HSV)
        hsv[:, :, 2] = self._add_clipped(hsv[:, :, 2], value)
        self.image = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
        return self

    def flip(self, flip_code):
        """Mirror the image via ``cv2.flip``.

        Per the OpenCV documentation, ``flip_code`` 0 flips around the
        x-axis, a positive value around the y-axis and a negative value
        around both.
        """
        self.image = cv2.flip(self.image, flip_code)
        return self

    def rotate(self, angle):
        """Rotate by ``angle`` degrees using ``scipy.ndimage.rotate`` defaults.

        With SciPy's defaults the output array is enlarged as needed to
        contain the whole rotated image.
        """
        self.image = ndimage.rotate(self.image, angle)
        return self

    def _translate(self, shift_x, shift_y):
        """Shift the image by (shift_x, shift_y) pixels, keeping its size."""
        matrix = np.float32([[1, 0, shift_x], [0, 1, shift_y]])
        # shape[:2] works for both color (3-D) and grayscale (2-D) arrays.
        rows, cols = self.image.shape[:2]
        self.image = cv2.warpAffine(self.image, matrix, (cols, rows))
        return self

    def translate_horizontal(self, shift):
        """Shift the image ``shift`` pixels along the x-axis."""
        return self._translate(shift, 0)

    def translate_vertical(self, shift):
        """Shift the image ``shift`` pixels along the y-axis."""
        return self._translate(0, shift)

    def getFilepath(self):
        """Return the path this image was created with."""
        return self.file_path

    def getColorcode(self):
        """Return the channel-order flag (0 = BGR, 1 = RGB, 2 = grayscale)."""
        return self.color_code