In [None]:
import os
from PIL import Image
from torchvision import transforms
import torchvision
import torch
import numpy as np

# Side length, in pixels, of the square images produced by transform_image
# (i.e. the output is OUTPUT_IMAGE_SIZE x OUTPUT_IMAGE_SIZE).
OUTPUT_IMAGE_SIZE = 224

# Mask R-CNN (ResNet-50 FPN) with torchvision's default weights, built once
# at module load and shared by every segment_car call. eval() returns the
# module itself, so the model is switched to evaluation mode inline.
SEGMENTATION_MODEL = torchvision.models.detection.maskrcnn_resnet50_fpn(
    weights='DEFAULT'
).eval()

def segment_car(img):
    """Isolate the largest detected car in ``img`` on a black background.

    The module-level ``SEGMENTATION_MODEL`` is run on the image. Among the
    detections labelled as a car, the one whose bounding box has the largest
    area is selected; every pixel outside its mask is set to zero and the
    result is cropped to the mask's extent plus a small margin.

    Args:
        img: PIL image. The mask is broadcast as (H, W, 1) against the pixel
            array, so a 3-channel image is expected (``transform_image``
            converts to RGB before calling).

    Returns:
        A new PIL image with the masked, cropped car. If the model reports no
        car at all, or the selected mask has no pixel above the threshold, an
        unmodified copy of ``img`` is returned so callers always receive a
        usable, non-empty image.
    """
    # Label id 3 is 'car' in the COCO category list used by torchvision's
    # pretrained detection weights.
    car_label = 3
    # Per-pixel mask probability above which a pixel belongs to the object.
    mask_threshold = 0.5
    # Empty rows/columns kept around the mask on every side. The previous
    # "-5 / +5" index arithmetic effectively left 6, which is preserved here.
    margin = 6

    img_tensor = transforms.ToTensor()(img)
    with torch.no_grad():
        predictions = SEGMENTATION_MODEL([img_tensor])[0]

    boxes = predictions['boxes'].cpu().numpy()
    labels = predictions['labels'].cpu().numpy()
    masks = predictions['masks'].cpu().numpy()

    # Only consider detections that are actually cars. Taking an argmax over
    # "area or 0" would otherwise fall back to index 0 (an arbitrary non-car
    # object) when no car is present, and raise on an empty prediction list.
    car_indices = np.flatnonzero(labels == car_label)
    if car_indices.size == 0:
        return img.copy()

    car_boxes = boxes[car_indices]
    areas = (car_boxes[:, 2] - car_boxes[:, 0]) * (car_boxes[:, 3] - car_boxes[:, 1])
    best = car_indices[np.argmax(areas)]

    mask = masks[best, 0] > mask_threshold

    # Derive the crop window from the mask itself rather than from pixel
    # sums, so genuinely black car pixels are not mistaken for background.
    rows = np.flatnonzero(mask.any(axis=1))
    cols = np.flatnonzero(mask.any(axis=0))
    if rows.size == 0:
        # An all-False mask would produce an empty crop that cannot be
        # converted back to an image.
        return img.copy()

    masked_np = (mask[:, :, None] * np.asarray(img)).astype(np.uint8)

    height, width = mask.shape
    top = max(0, rows[0] - margin)
    bottom = min(height, rows[-1] + margin + 1)
    left = max(0, cols[0] - margin)
    right = min(width, cols[-1] + margin + 1)

    return Image.fromarray(masked_np[top:bottom, left:right, :])

def transform_image(img):
    """Segment the car in ``img`` and fit it onto a square black canvas.

    The image is converted to RGB, passed through ``segment_car`` and then
    placed in the top-left corner of an ``OUTPUT_IMAGE_SIZE`` x
    ``OUTPUT_IMAGE_SIZE`` canvas, padded with black on the right and bottom.

    * If both sides are smaller than ``OUTPUT_IMAGE_SIZE`` the image is only
      padded, never upscaled.
    * Otherwise it is resized so that its longer side equals
      ``OUTPUT_IMAGE_SIZE``; the shorter side is scaled by the same factor
      and truncated to an integer (but never below 1 pixel).

    Args:
        img: PIL image in any mode accepted by ``Image.convert("RGB")``.

    Returns:
        A PIL image of size ``OUTPUT_IMAGE_SIZE`` x ``OUTPUT_IMAGE_SIZE``.
    """
    img = segment_car(img.convert("RGB"))
    width, height = img.size

    steps = []
    if width < OUTPUT_IMAGE_SIZE and height < OUTPUT_IMAGE_SIZE:
        # Small image: keep its native resolution and only pad.
        new_height, new_width = height, width
    else:
        # max(1, ...) guards against extreme aspect ratios where truncation
        # would yield a 0-pixel side, which Resize cannot handle.
        if height >= width:
            new_height = OUTPUT_IMAGE_SIZE
            new_width = max(1, int((OUTPUT_IMAGE_SIZE / height) * width))
        else:
            new_height = max(1, int((OUTPUT_IMAGE_SIZE / width) * height))
            new_width = OUTPUT_IMAGE_SIZE
        steps.append(transforms.Resize((new_height, new_width)))

    # Pad takes (left, top, right, bottom): fill only to the right and below.
    padding = (0, 0, OUTPUT_IMAGE_SIZE - new_width, OUTPUT_IMAGE_SIZE - new_height)
    steps.extend([
        transforms.Pad(padding=padding, fill=0),
        # NOTE(review): the tensor round trip is kept from the original
        # pipeline; it looks redundant -- confirm before removing.
        transforms.ToTensor(),
        transforms.ToPILImage(),
    ])
    return transforms.Compose(steps)(img)

# Root of the input dataset and root of the mirrored output tree.
main_dir = "./stanford_cars/"
out_dir = "./transformed_stanford_cars_rough/"

# Walk the "train" and "test" splits, recreate the same directory layout under
# out_dir, and write a transformed copy of every .jpg file found.
for split in ["train", "test"]:
    split_in = os.path.join(main_dir, split)
    split_out = os.path.join(out_dir, split)

    for root, dirs, files in os.walk(split_in):
        # Mirror the current sub-directory (relative to the split root).
        rel_path = os.path.relpath(root, split_in)
        target_root = os.path.join(split_out, rel_path)
        os.makedirs(target_root, exist_ok=True)

        for f in files:
            if not f.lower().endswith(".jpg"):
                continue

            in_path = os.path.join(root, f)
            out_path = os.path.join(target_root, f)

            # Image.open is lazy and keeps the file handle open; the context
            # manager closes it deterministically once the transformed copy
            # (a new, independent image) has been produced.
            with Image.open(in_path) as img:
                img_t = transform_image(img)
            img_t.save(out_path)
