In [2]:
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import os
from ultralytics import YOLO
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
import numpy as np
import cv2
import pickle

ModuleNotFoundError: No module named 'ultralytics'

In [None]:
# Runtime configuration.
# cuDNN autotuning pays off because every training input is resized to a fixed 128x128.
torch.backends.cudnn.benchmark = True

# NOTE(review): absolute Google Drive path (Colab); change this when running elsewhere.
abs_path = '/content/drive/Othercomputers/MacBook Pro (Personal)/Documents/COLUMBIA UNIVERSITY/MSCS/Research/Knolling Bot/Preliminary Pipeline/'

# Seed the RNGs this notebook draws from so reruns are comparable: NumPy drives
# object placement in composite_tidy_image; torch drives weight init, DataLoader
# shuffling and the sampled timesteps. (cudnn.benchmark may still make GPU
# results differ slightly between runs.)
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
yolo_model = YOLO('yolov8x.pt').to(device)

In [None]:
def load_images(directory):
    """Load every ``.png`` / ``.jpg`` file in ``directory`` as an RGB PIL image.

    Files are read in sorted filename order, so two folders whose files share
    names line up index-by-index. Each file is opened in a ``with`` block and
    fully decoded by ``convert('RGB')`` before its handle is closed.
    ``Image.open`` on its own is lazy and would keep one file descriptor open
    per image. The RGB conversion also guarantees the 3 channels that the
    3-value ``Normalize`` and the 3-channel U-Net expect (PNGs may be RGBA,
    palette or grayscale).

    Args:
        directory: Folder to scan (non-recursive, via ``os.listdir``).

    Returns:
        List of RGB PIL images, in sorted filename order.
    """
    images = []
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith(('.png', '.jpg')):
            continue
        with Image.open(os.path.join(directory, filename)) as img:
            images.append(img.convert('RGB'))
    return images

# Load the "before" (messy) and "after" (tidy) image folders. load_images sorts
# by filename and the dataset pairs images by index, so the two folders
# presumably contain matching filenames.
messy_images, tidy_images = (
    load_images(os.path.join(abs_path, subdir))
    for subdir in ('data/images_before_small/', 'data/images_after_small/')
)

# Resize to the 128x128 U-Net input size, convert to a [0, 1] tensor, then
# normalize each of the 3 channels to [-1, 1] (mean 0.5, std 0.5).
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

In [None]:
class MessyTidyDataset(Dataset):
    """Index-aligned ``(messy, tidy)`` image pairs for supervised training.

    Item ``idx`` is ``(messy_images[idx], tidy_images[idx])``, with both
    elements passed through ``transform`` when one is given.

    Raises:
        ValueError: if the two sequences have different lengths.
    """

    def __init__(self, messy_images, tidy_images, transform=None):
        # Pairs are matched purely by position. Unequal lengths would either
        # silently ignore the extra tidy images or raise IndexError mid-epoch.
        if len(messy_images) != len(tidy_images):
            raise ValueError(
                f"messy_images ({len(messy_images)}) and tidy_images "
                f"({len(tidy_images)}) must have the same length"
            )
        self.messy_images = messy_images
        self.tidy_images = tidy_images
        self.transform = transform

    def __len__(self):
        return len(self.messy_images)

    def __getitem__(self, idx):
        messy_image = self.messy_images[idx]
        tidy_image = self.tidy_images[idx]

        if self.transform is not None:
            messy_image = self.transform(messy_image)
            tidy_image = self.transform(tidy_image)

        return messy_image, tidy_image

# Wrap the paired images in a shuffling loader; pinned memory speeds up
# host-to-GPU copies of each batch.
dataset = MessyTidyDataset(
    messy_images=messy_images,
    tidy_images=tidy_images,
    transform=transform,
)
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# 4-level U-Net on 128x128 RGB inputs; self-attention only at the deepest
# down level and the first up level.
unet = UNet2DModel(
    sample_size=128,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(64, 128, 256, 512),
    down_block_types=("DownBlock2D",) * 3 + ("AttnDownBlock2D",),
    up_block_types=("AttnUpBlock2D",) + ("UpBlock2D",) * 3,
)

# The pipeline wraps this same `unet` object, so model.unet and unet share weights.
model = DDPMPipeline(
    unet=unet,
    scheduler=DDPMScheduler(num_train_timesteps=1000),
).to(device)
optimizer = Adam(params=model.unet.parameters(), lr=5e-4, betas=(0.5, 0.999))
scaler = torch.cuda.amp.GradScaler()

In [None]:
def extract_objects(yolo_model, image, conf_threshold=0.5, iou_threshold=0.45, imgsz=128):
    """Detect objects in ``image`` with YOLO and crop each detection out.

    Args:
        yolo_model: Callable detector. It is invoked as
            ``yolo_model(image, conf=..., iou=..., imgsz=...)`` and its first
            result must expose ``.boxes.xyxy``.
        image: PIL image to run detection on and crop from.
        conf_threshold: Minimum detection confidence, forwarded as ``conf``.
        iou_threshold: NMS IoU threshold, forwarded as ``iou``.
        imgsz: Inference size forwarded to the detector.

    Returns:
        List of ``(crop, (x1, y1, x2, y2))`` tuples. Box coordinates are
        truncated to ``int`` and the crop is ``image.crop`` of that box.
    """
    detections = yolo_model(image, conf=conf_threshold, iou=iou_threshold, imgsz=imgsz)
    boxes_xyxy = detections[0].boxes.xyxy.cpu().numpy()

    crops = []
    for raw_box in boxes_xyxy:
        box = tuple(int(coord) for coord in raw_box[:4])
        crops.append((image.crop(box), box))
    return crops

In [None]:
def composite_tidy_image(objects, image_size=(128, 128), max_attempts=100):
    """Paste object crops at random positions, avoiding overlap when possible.

    Args:
        objects: Iterable of ``(crop, bbox)`` pairs as returned by
            ``extract_objects``. Only the crop (a PIL image) is used.
        image_size: Canvas ``(width, height)`` in pixels (PIL order).
        max_attempts: Random placements tried per object while looking for a
            spot that does not overlap any previously pasted object.

    Returns:
        New RGB PIL image with a white background and every crop pasted once.
        An object that finds no free spot within ``max_attempts`` is pasted at
        one more random position regardless of overlap.
    """
    canvas_w, canvas_h = image_size
    tidy_image = Image.new('RGB', image_size, (255, 255, 255))
    used_positions = []

    def _random_box(obj_w, obj_h):
        # np.random.randint's upper bound is exclusive, hence the +1 so an
        # object may touch the right/bottom edge. Without it, a crop exactly
        # as wide/tall as the canvas made randint(0, 0) raise ValueError.
        # max(..., 0) keeps crops larger than the canvas valid (pasted at 0, clipped).
        x = np.random.randint(0, max(canvas_w - obj_w, 0) + 1)
        y = np.random.randint(0, max(canvas_h - obj_h, 0) + 1)
        return (x, y, x + obj_w, y + obj_h)

    def _overlaps(box_a, box_b):
        # True iff the intersection area is strictly positive; touching edges
        # do not count (same criterion as "IoU > 0").
        return (min(box_a[2], box_b[2]) > max(box_a[0], box_b[0])
                and min(box_a[3], box_b[3]) > max(box_a[1], box_b[1]))

    for obj, _bbox in objects:
        obj_w, obj_h = obj.size
        for _ in range(max_attempts):
            box = _random_box(obj_w, obj_h)
            if not any(_overlaps(box, used) for used in used_positions):
                break
        else:
            # No overlap-free spot found: fall back to one unconstrained random spot.
            box = _random_box(obj_w, obj_h)
        tidy_image.paste(obj, box[:2])
        used_positions.append(box)
    return tidy_image

In [None]:
def iou(boxA, boxB):
    """Intersection-over-union of two axis-aligned boxes given as ``(x1, y1, x2, y2)``.

    Returns:
        ``intersection / union`` as a float. Returns ``0.0`` when the union
        area is zero (e.g. two degenerate boxes) instead of raising
        ZeroDivisionError.
    """
    inter_w = min(boxA[2], boxB[2]) - max(boxA[0], boxB[0])
    # Bottom edge of the intersection must come from BOTH boxes; the previous
    # version compared boxB[3] with itself.
    inter_h = min(boxA[3], boxB[3]) - max(boxA[1], boxB[1])
    inter_area = max(0, inter_w) * max(0, inter_h)

    area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    union = area_a + area_b - inter_area
    if union <= 0:
        return 0.0
    return inter_area / float(union)

In [None]:
def _tensor_to_pil(image_tensor):
    """Convert one normalized ``(C, H, W)`` tensor back to a PIL image.

    Inverts the ``Normalize(mean=0.5, std=0.5)`` of the module-level
    ``transform`` (``[-1, 1] -> [0, 1]``), because ``ToPILImage`` expects
    floats in ``[0, 1]``. Out-of-range values are clamped instead of wrapping
    around in the uint8 conversion.
    """
    unnormalized = image_tensor.detach().float().cpu() * 0.5 + 0.5
    return transforms.ToPILImage()(unnormalized.clamp(0.0, 1.0))


def _composite_batch(yolo_model, messy_batch, device):
    """Build the U-Net input batch: one YOLO object composite per messy image."""
    composites = [
        transform(composite_tidy_image(extract_objects(yolo_model, _tensor_to_pil(messy))))
        for messy in messy_batch
    ]
    return torch.stack(composites).to(device, non_blocking=True)


def train_model(unet, dataloader, yolo_model, optimizer, device, scaler, num_epochs=10):
    """Train ``unet`` to map object composites of messy scenes to their tidy scenes.

    For each batch, every messy image is run through YOLO (``extract_objects``).
    Its detected crops are re-pasted on a blank canvas (``composite_tidy_image``),
    and the U-Net is optimized with an MSE loss between its output on that
    composite and the paired tidy image.

    Args:
        unet: U-Net being trained (updated in place).
        dataloader: Yields ``(messy, tidy)`` batches of normalized image tensors.
        yolo_model: Detector forwarded to ``extract_objects``.
        optimizer: Optimizer over ``unet``'s parameters.
        device: Device ``unet`` lives on (``torch.device`` or string).
        scaler: Gradient scaler used for the mixed-precision backward/step.
        num_epochs: Number of passes over ``dataloader``.

    Prints the mean batch loss per epoch. Shows the first composite of each
    epoch as a visual sanity check.
    """
    device_type = torch.device(device).type
    # float16 autocast is only used on CUDA; elsewhere it is disabled outright.
    use_amp = device_type == 'cuda'

    for epoch in range(num_epochs):
        unet.train()
        running_loss, num_batches = 0.0, 0

        for batch_idx, (messy_image, tidy_image) in enumerate(dataloader):
            # YOLO consumes PIL images, so the messy batch is used on the CPU;
            # only the target has to be on the training device.
            tidy_image = tidy_image.to(device, non_blocking=True)
            # Build each input from *its own* messy image, so input i and
            # target i always describe the same scene and batch sizes match.
            model_input = _composite_batch(yolo_model, messy_image, device)

            if batch_idx == 0:
                plt.imshow(_tensor_to_pil(model_input[0]))
                plt.axis('off')
                plt.show()

            # 1000 matches DDPMScheduler(num_train_timesteps=1000) used for the pipeline.
            timestep = torch.randint(0, 1000, (model_input.shape[0],), device=device).long()

            optimizer.zero_grad()
            # Only the forward pass runs under autocast; backward runs outside it,
            # as the PyTorch AMP docs recommend.
            with torch.amp.autocast(device_type=device_type, enabled=use_amp):
                prediction = unet(model_input, timestep).sample
                loss = torch.nn.functional.mse_loss(prediction.float(), tidy_image)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}]: dataloader yielded no batches")
            continue
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / num_batches:.4f}")

# Train the U-Net in place, then save the whole pipeline. The pipeline wraps
# this same `unet` object, so the saved weights are the trained ones.
train_model(
    unet=unet,
    dataloader=dataloader,
    yolo_model=yolo_model,
    optimizer=optimizer,
    device=device,
    scaler=scaler,
)
model.save_pretrained(os.path.join(abs_path, 'trained_model/'))

In [None]:
def generate_tidy_image(model, yolo_model, messy_image, transform, device):
    """Run the trained U-Net on a composite of the objects detected in ``messy_image``.

    During training the U-Net only ever receives composites built by
    ``composite_tidy_image``, never raw photos, so inference builds the same
    kind of input. The messy image is first brought to the model's working
    resolution via ``transform``. Objects are then detected and re-pasted on a
    blank canvas, and that composite is fed to ``model.unet``.

    Args:
        model: Pipeline exposing ``.unet``.
        yolo_model: Detector forwarded to ``extract_objects``.
        messy_image: PIL image of the messy scene (any size or mode).
        transform: Preprocessing returning a normalized ``(C, H, W)`` tensor.
            Assumed to normalize with mean=std=0.5 (``[0, 1] -> [-1, 1]``),
            like the notebook's ``transform``; the inverse below relies on it.
        device: Device the U-Net lives on.

    Returns:
        PIL image of the U-Net prediction, mapped back to ``[0, 1]`` before conversion.

    Side effect: puts ``model.unet`` in eval mode.
    """
    to_pil = transforms.ToPILImage()

    def _denormalize(tensor):
        # Invert Normalize(0.5, 0.5): [-1, 1] -> [0, 1], clamped because
        # ToPILImage expects [0, 1] floats and would wrap out-of-range values.
        return (tensor.float() * 0.5 + 0.5).clamp(0.0, 1.0)

    model.unet.eval()
    with torch.no_grad():
        # Same resolution/preprocessing the training composites were built from.
        messy_small = to_pil(_denormalize(transform(messy_image.convert('RGB'))))
        objects = extract_objects(yolo_model, messy_small)
        composite = composite_tidy_image(objects, image_size=messy_small.size)

        model_input = transform(composite).unsqueeze(0).to(device)
        timestep = torch.tensor([0], device=device).long()
        prediction = model.unet(model_input, timestep).sample
        # squeeze(0) drops only the batch dim (plain squeeze() could also drop C).
        return to_pil(_denormalize(prediction.squeeze(0).cpu()))

In [None]:
img_id = 3
generated_image = generate_tidy_image(model, yolo_model, messy_images[img_id], transform, device)

# Model output, input and ground truth side by side, each labelled.
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = (
    (generated_image, 'Generated (model output)'),
    (messy_images[img_id], 'Messy (input)'),
    (tidy_images[img_id], 'Tidy (ground truth)'),
)
for ax, (panel_image, panel_title) in zip(axes, panels):
    ax.imshow(panel_image)
    ax.set_title(panel_title)
    ax.axis('off')
fig.tight_layout()
plt.show()