# In [None]:  (notebook cell marker left over from the Jupyter export)
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
#import torchvision.transforms as transforms
from torchvision import transforms

from torch.utils.data import DataLoader, Dataset
import os
import numpy as np
from PIL import Image, ImageDraw
from pycocotools.coco import COCO
import torchvision.transforms.functional as TF
from einops import rearrange
from datasets import load_dataset
from io import BytesIO
import requests
import numpy as np
from PIL import Image, ImageDraw
#from huggingface_hub import login
#login(os.environ["HF_TOKEN"])  # never hard-code access tokens; the token previously written here must be revoked

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from PIL import Image
import torch
from transformers import DetrImageProcessor
from datasets import load_dataset
from transformers import DetrImageProcessor, DetrForObjectDetection, DetrConfig
import torch
from torch import nn
from torch.utils.data import DataLoader

# DETR image processor shared by transform_data and CustomCocoDataset
processor = DetrImageProcessor.from_pretrained(
    pretrained_model_name_or_path='facebook/detr-resnet-101',
)

# First 1% of the COCO detection training split
dataset = load_dataset(
    path="detection-datasets/coco",
    split='train[:1%]',
)

def transform_data(sample):
    """Run the module-level ``processor`` on one sample and convert its boxes.

    Boxes are converted from ``[x_min, y_min, width, height]`` to corner
    format ``[x0, y0, x1, y1]`` in the coordinates of the original image
    (no rescaling is applied).

    Args:
        sample: Mapping with an ``'image'`` (must expose ``.size`` and be
            accepted by ``processor``) and an ``'objects'`` mapping holding
            the ``'bbox'`` and ``'category'`` lists.

    Returns:
        Tuple ``(inputs, tensor_boxes, labels)`` where ``inputs`` is the
        processor output, ``tensor_boxes`` is a float32 tensor of shape
        ``(num_boxes, 4)`` and ``labels`` is an int64 tensor of shape
        ``(num_boxes,)``.
    """
    image = sample['image']
    inputs = processor(images=image, return_tensors="pt")

    # Directly use image dimensions for verification
    actual_width, actual_height = image.size
    print("Actual dimensions:", actual_width, actual_height)

    # Convert [x_min, y_min, width, height] to [x0, y0, x1, y1] without scaling
    bboxes = []
    for x0, y0, box_width, box_height in sample['objects']['bbox']:
        x1 = x0 + box_width
        y1 = y0 + box_height
        print(f"Box: ({x0}, {y0}, {x1}, {y1})")  # Debug print
        bboxes.append([x0, y0, x1, y1])

    # reshape keeps the (N, 4) layout even when the image has no boxes;
    # torch.tensor([]) alone would have shape (0,).
    tensor_boxes = torch.tensor(bboxes, dtype=torch.float32).reshape(-1, 4)

    labels = torch.tensor(sample['objects']['category'], dtype=torch.long)

    return inputs, tensor_boxes, labels

# Example usage for demonstration: fetch the first sample of the loaded split
# and dump its raw contents to stdout.
sample_data = dataset[0]
print(sample_data)
# The triple-quoted block below is a bare string expression, i.e. it is never
# executed. It holds a disabled matplotlib demo that would draw the converted
# [x0, y0, x1, y1] boxes from transform_data on top of the sample image.
'''inputs, boxes, labels = transform_data(sample_data)

import matplotlib.pyplot as plt
import matplotlib.patches as patches

def show_image_with_boxes(image, boxes):
    fig, ax = plt.subplots(1)
    ax.imshow(image)
    for box in boxes:
        # Adjust rectangle drawing based on [x0, y0, x1, y1]
        rect = patches.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
    plt.show()

inputs, boxes, labels = transform_data(sample_data)
show_image_with_boxes(sample_data['image'], boxes.numpy())  # Ensure boxes are converted to numpy if needed

'''

# Helper modules
class PatchEmbed(nn.Module):
    """Embed non-overlapping image patches as a token sequence.

    A convolution whose kernel size and stride both equal ``patch_size``
    maps every patch to one ``embed_dim``-channel vector.

    Args:
        patch_size: Side length of each patch (kernel size and stride).
        in_channels: Number of channels of the input image.
        embed_dim: Size of each patch embedding.
    """

    def __init__(self, patch_size, in_channels, embed_dim):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Return patch embeddings of shape ``(batch, h * w, embed_dim)``.

        ``h`` and ``w`` are the spatial sizes of the convolution output;
        tokens are ordered row by row (token index ``row * w + col``).
        """
        x = self.proj(x)  # (batch, embed_dim, h, w)
        # Same layout as einops 'b e h w -> b (h w) e', using native tensor
        # ops. No shape print here: forward runs once per batch.
        return x.flatten(2).transpose(1, 2)

class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` standard PyTorch transformer encoder layers.

    Each layer uses ``num_heads`` attention heads and a feed-forward width of
    four times ``embed_dim``. ``batch_first`` is left at the PyTorch default.
    """

    def __init__(self, embed_dim, num_layers, num_heads):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=4 * embed_dim,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, x):
        """Pass ``x`` through the encoder stack and return the result."""
        encoded = self.encoder(x)
        return encoded

class ConvBlock(nn.Module):
    """Two consecutive 3x3 conv -> batch-norm -> ReLU stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Padding of 1 preserves the spatial size, and the
    convolutions carry no bias.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # Flat Sequential (indices 0-5) so parameter names stay conv.<idx>.*
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x``."""
        out = self.conv(x)
        return out

class EncoderBlock(nn.Module):
    """ConvBlock followed by 2x2 max pooling.

    ``forward`` returns both the full-resolution features (used as a skip
    connection) and their pooled version (input to the next stage).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_block = ConvBlock(in_channels, out_channels)
        self.downsample = nn.MaxPool2d(2)

    def forward(self, x):
        """Return ``(features, pooled_features)`` for input ``x``."""
        features = self.conv_block(x)
        return features, self.downsample(features)

class DecoderBlock(nn.Module):
    """Upsample, fuse with a skip connection, then refine with a ConvBlock.

    Args:
        in_channels: Channels of the tensor passed as ``x``.
        skip_channels: Channels produced by the transposed convolution.
        out_channels: Channels of the block output.

    Note:
        The fusing ConvBlock expects ``in_channels + skip_channels`` input
        channels and the upsampled ``x`` contributes ``skip_channels`` of
        them, so the ``skip`` tensor itself must carry ``in_channels``
        channels.
    """

    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        # 2x spatial upsampling; output channel count is skip_channels
        self.upconv = nn.ConvTranspose2d(in_channels, skip_channels, kernel_size=2, stride=2)
        self.conv_block = ConvBlock(in_channels + skip_channels, out_channels)

    def forward(self, x, skip):
        """Return the fused features at the spatial size of ``skip``."""
        x = self.upconv(x)
        # Pooling an odd-sized map floors its size, so the 2x upsample can
        # differ from the skip size; bilinear interpolation closes the gap.
        if x.shape[2:] != skip.shape[2:]:
            x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, skip], dim=1)
        return self.conv_block(x)




class HybridUNet(nn.Module):
    """U-Net style encoder/decoder operating on a grid of patch embeddings.

    The image is embedded patch-wise, the embeddings are laid out as a 2-D
    feature map, and four encoder and four decoder stages with skip
    connections process it. Two 1x1 convolutions produce per-cell class
    scores and box values.

    Args:
        in_channels: Channels of the input image.
        num_classes: Number of object classes; one extra output channel is
            added for background.
        patch_size: Patch side length used by the patch embedding.
        embed_dim: Channels of the patch embedding.
        num_heads: Attention heads of ``transformer_encoder``.
        num_enc_layers: Layers of ``transformer_encoder``.
        num_queries: Accepted for interface compatibility; not used.

    Note:
        ``transformer_encoder`` is constructed (and therefore part of the
        state dict) but ``forward`` does not call it.
    """

    def __init__(self, in_channels, num_classes, patch_size=16, embed_dim=768, num_heads=8, num_enc_layers=4, num_queries=100):
        super().__init__()
        self.patch_embed = PatchEmbed(patch_size, in_channels, embed_dim)
        self.transformer_encoder = TransformerEncoder(embed_dim, num_enc_layers, num_heads)

        self.encoder1 = EncoderBlock(embed_dim, 64)
        self.encoder2 = EncoderBlock(64, 128)
        self.encoder3 = EncoderBlock(128, 256)
        self.encoder4 = EncoderBlock(256, 512)

        self.decoder1 = DecoderBlock(512, 256, 256)
        self.decoder2 = DecoderBlock(256, 128, 128)
        self.decoder3 = DecoderBlock(128, 64, 64)
        self.decoder4 = DecoderBlock(64, embed_dim, 64)

        self.classifier = nn.Conv2d(64, num_classes + 1, kernel_size=1)  # +1 for background class
        self.bbox_predictor = nn.Conv2d(64, 4, kernel_size=1)

    def forward(self, x):
        """Return per-cell predictions for a batch of images.

        Returns:
            Dict with ``'logits'`` (``num_classes + 1`` channels) and
            ``'pred_boxes'`` (4 channels), both 4-D maps at the patch-grid
            resolution.
        """
        # Derive the patch grid from the input size and the embedding stride.
        # Taking sqrt(num_patches) only works for square inputs and can
        # truncate wrongly because of floating-point error.
        stride_h, stride_w = self.patch_embed.proj.stride
        grid_h = x.shape[-2] // stride_h
        grid_w = x.shape[-1] // stride_w

        x = self.patch_embed(x)
        x = rearrange(x, 'b (h w) e -> b e h w', h=grid_h, w=grid_w)

        # Encoding: keep each stage's full-resolution output as a skip
        enc_features = []
        for encoder in (self.encoder1, self.encoder2, self.encoder3, self.encoder4):
            skip, x = encoder(x)
            enc_features.append(skip)

        # Decoding: consume the skips deepest-first
        for decoder in (self.decoder1, self.decoder2, self.decoder3, self.decoder4):
            x = decoder(x, enc_features.pop())

        # Prediction heads
        logits = self.classifier(x)
        bbox_outputs = self.bbox_predictor(x)

        return {'logits': logits, 'pred_boxes': bbox_outputs}



def generate_segmentation_mask(anns, height, width):
    """Rasterize polygon annotations into one binary mask.

    Args:
        anns: Iterable of annotation mappings. Only entries whose
            ``'segmentation'`` value is a list are used; each element of that
            list is a flat coordinate list ``[x1, y1, x2, y2, ...]``.
            Coordinate lists of odd length are ignored.
        height: Mask height in pixels.
        width: Mask width in pixels.

    Returns:
        ``np.uint8`` array of shape ``(height, width)`` that is 1 inside or
        on the outline of any polygon and 0 elsewhere.
    """
    polygons = []
    for ann in anns:
        # Annotations without a list-valued segmentation are skipped
        if 'segmentation' not in ann or not isinstance(ann['segmentation'], list):
            continue
        for seg in ann['segmentation']:
            if len(seg) % 2 == 0:  # need complete (x, y) pairs
                points = np.array(seg).reshape((-1, 2))
                polygons.append([tuple(point) for point in points])

    if not polygons:
        return np.zeros((height, width), dtype=np.uint8)

    # Draw every polygon on a single canvas: filling with 1 is already a
    # union, so no per-polygon image/array allocation is needed.
    canvas = Image.new('L', (width, height), 0)
    draw = ImageDraw.Draw(canvas)
    for polygon in polygons:
        draw.polygon(polygon, outline=1, fill=1)

    return np.array(canvas, dtype=bool).astype(np.uint8)


def train_val_split(dataset, val_split=0.2):
    """Randomly partition ``dataset`` into training and validation subsets.

    Indices are shuffled with NumPy's global random state; the first
    ``floor(val_split * len(dataset))`` shuffled indices form the validation
    subset and the remainder the training subset.

    Returns:
        Tuple ``(train_subset, val_subset)`` of ``torch.utils.data.Subset``.
    """
    total = len(dataset)
    shuffled = list(range(total))
    num_val = int(np.floor(val_split * total))
    np.random.shuffle(shuffled)

    train_part = torch.utils.data.Subset(dataset, shuffled[num_val:])
    val_part = torch.utils.data.Subset(dataset, shuffled[:num_val])
    return train_part, val_part

def validate_model(model, dataloader, loss_fn, device):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    The model is switched to eval mode (and left there) and gradients are
    disabled. Model outputs are bilinearly resized to the spatial size of the
    masks before the loss is computed.

    Args:
        model: Module mapping an image batch to a 4-D score tensor.
        dataloader: Iterable yielding ``(images, masks)`` tensor pairs.
        loss_fn: Callable ``loss_fn(scores, masks)`` returning a scalar tensor.
        device: Device the batches are moved to.

    Returns:
        Mean loss as a float, or NaN when ``dataloader`` yields no batches.
    """
    model.eval()
    val_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for images, masks in dataloader:
            images, masks = images.to(device), masks.to(device)
            masks = torch.squeeze(masks, 1).long()

            outputs = model(images)
            # Resize to the masks' own size instead of a fixed (224, 224),
            # so the loss also works for other resolutions.
            outputs_upsampled = F.interpolate(
                outputs, size=masks.shape[-2:], mode='bilinear', align_corners=False
            )

            val_loss += loss_fn(outputs_upsampled, masks).item()
            num_batches += 1

    if num_batches == 0:
        # No data: the mean is undefined; avoid a ZeroDivisionError
        return float('nan')
    return val_loss / num_batches

import torchvision.transforms.functional as TF

class CustomCocoDataset(Dataset):
    """Wrap a COCO-style dataset to yield processed images and scaled boxes.

    Args:
        dataset: Indexable collection of samples with ``'image'``,
            ``'width'``, ``'height'`` and ``'objects'`` (``'bbox'`` in
            ``[x, y, w, h]`` plus ``'category'``) entries.
        target_size: ``(height, width)`` the image is resized to before it is
            handed to the module-level ``processor``.

    NOTE(review): boxes are expressed in ``target_size`` pixel coordinates.
    If ``processor`` resizes the image again, boxes and ``pixel_values`` no
    longer share a coordinate frame; confirm the processor configuration.
    """

    def __init__(self, dataset, target_size=(800, 800)):
        self.dataset = dataset
        self.target_size = target_size

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _scaled_xyxy_boxes(bboxes, scale_x, scale_y):
        """Convert ``[x, y, w, h]`` boxes to scaled ``[x0, y0, x1, y1]`` rows."""
        corners = [
            [x * scale_x, y * scale_y, (x + w) * scale_x, (y + h) * scale_y]
            for x, y, w, h in bboxes
        ]
        # reshape keeps the (N, 4) layout for images without any box
        return torch.tensor(corners, dtype=torch.float32).reshape(-1, 4)

    def __getitem__(self, idx):
        """Return ``{'inputs', 'boxes', 'labels'}`` tensors for sample ``idx``."""
        sample = self.dataset[idx]
        image = sample['image']

        # Ensure image is in RGB
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # torchvision's resize interprets the size as (height, width)
        image = TF.resize(image, self.target_size)
        inputs = processor(images=image, return_tensors="pt")

        # Scale boxes with the same (height, width) convention used by the
        # resize above; mixing the two only works for square targets.
        target_height, target_width = self.target_size
        scale_x = target_width / sample['width']
        scale_y = target_height / sample['height']
        tensor_boxes = self._scaled_xyxy_boxes(sample['objects']['bbox'], scale_x, scale_y)

        labels = torch.tensor(sample['objects']['category'], dtype=torch.long)

        return {'inputs': inputs['pixel_values'].squeeze(0), 'boxes': tensor_boxes, 'labels': labels}

# Adjust collate function as needed:
def collate_fn(batch):
    """Collate samples with a varying number of boxes into padded tensors.

    Args:
        batch: List of dicts with ``'inputs'`` (equally shaped tensors),
            ``'boxes'`` (``(n_i, 4)`` tensor, or an empty tensor) and
            ``'labels'`` (``(n_i,)`` int64 tensor).

    Returns:
        Dict with ``'inputs'`` (stacked), ``'boxes'`` ``(B, max_n, 4)``,
        ``'labels'`` ``(B, max_n)`` and ``'box_masks'`` ``(B, max_n)`` bool,
        where ``box_masks`` is True for real entries. Padding entries are
        zero, so padded labels equal 0 and must be filtered with the mask.
    """
    inputs = torch.stack([item['inputs'] for item in batch])

    # Normalize every box tensor to (n, 4); an empty tensor of shape (0,)
    # would otherwise not fit the (0, 4) slice it is assigned to below.
    boxes_per_item = [item['boxes'].reshape(-1, 4) for item in batch]
    max_boxes = max(item_boxes.shape[0] for item_boxes in boxes_per_item)

    padded_boxes = torch.zeros((len(batch), max_boxes, 4))
    box_masks = torch.zeros((len(batch), max_boxes), dtype=torch.bool)
    padded_labels = torch.zeros((len(batch), max_boxes), dtype=torch.long)

    for i, (item, item_boxes) in enumerate(zip(batch, boxes_per_item)):
        num_boxes = item_boxes.shape[0]
        padded_boxes[i, :num_boxes] = item_boxes
        padded_labels[i, :num_boxes] = item['labels']
        box_masks[i, :num_boxes] = True

    return {'inputs': inputs, 'boxes': padded_boxes, 'labels': padded_labels, 'box_masks': box_masks}


coco_dataset = CustomCocoDataset(dataset)

# Create training and validation datasets.
# The split is taken from the wrapped dataset: collate_fn needs the
# 'inputs'/'boxes'/'labels' dicts that only CustomCocoDataset produces.
train_dataset, val_dataset = train_val_split(coco_dataset, val_split=0.2)

# Loader over the full dataset, using the custom collate function
data_loader = DataLoader(coco_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

# Training and validation loaders; validation order does not need shuffling
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=False, collate_fn=collate_fn)

# Hyperparameters
num_classes = 80
lr = 1e-4
num_epochs = 20

# CPU is forced here; switch to the commented line to use a GPU when available
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

model = HybridUNet(
    in_channels=3,
    num_classes=num_classes,
    patch_size=16,
    embed_dim=768,
    num_heads=8,
    num_enc_layers=4,
).to(device)

# Loss functions, optimizer and learning-rate schedule
loss_fn = nn.CrossEntropyLoss()  # Or other suitable segmentation loss
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
classification_loss_fn = nn.CrossEntropyLoss()
bbox_regression_loss_fn = nn.SmoothL1Loss()

# Report which category ids occur. The labels are read from the raw
# annotation column, so no image is decoded, resized or run through the
# processor just to list the labels.
all_labels = torch.cat([
    torch.tensor(objects['category'], dtype=torch.long)
    for objects in dataset['objects']
])
unique_labels = torch.unique(all_labels)
print(f"Unique labels in the dataset: {unique_labels}")

num_epochs = 5

# Summed (not averaged) Smooth-L1 for the boxes, built once instead of once
# per batch.
bbox_loss_sum_fn = nn.SmoothL1Loss(reduction='sum')

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    num_steps = 0

    for batch_idx, batch in enumerate(data_loader):
        inputs = batch['inputs'].to(device)
        boxes = batch['boxes'].to(device)
        labels = batch['labels'].to(device)
        box_masks = batch['box_masks'].to(device)

        # Forward pass
        outputs = model(inputs)

        # The heads are 1x1 convolutions, so their outputs are
        # (batch, channels, H, W) maps. Move the channel axis last and
        # flatten the grid: every spatial cell becomes one prediction whose
        # last dimension is the class scores (or the 4 box values).
        # Flattening the raw map directly would treat the image width as the
        # class axis and make class ids >= width out of range for the loss.
        logits = outputs['logits'].flatten(2).transpose(1, 2)        # (B, H*W, C)
        pred_boxes = outputs['pred_boxes'].flatten(2).transpose(1, 2)  # (B, H*W, 4)
        num_cells = logits.size(1)

        # NOTE: there is no prediction/target matching here. As a simple
        # placeholder, the k-th real (non-padded) object of each image is
        # paired with the k-th spatial cell of that image.
        matched_logits, matched_labels = [], []
        matched_pred_boxes, matched_boxes = [], []
        for img_logits, img_pred_boxes, img_boxes, img_labels, img_mask in zip(
            logits, pred_boxes, boxes, labels, box_masks
        ):
            active_labels = img_labels[img_mask][:num_cells]
            active_boxes = img_boxes[img_mask][:num_cells]
            count = active_labels.size(0)
            matched_logits.append(img_logits[:count])
            matched_labels.append(active_labels)
            matched_pred_boxes.append(img_pred_boxes[:count])
            matched_boxes.append(active_boxes)

        logits_flat = torch.cat(matched_logits)            # (num_targets, C)
        target_labels_flat = torch.cat(matched_labels)     # (num_targets,)
        pred_boxes_flat = torch.cat(matched_pred_boxes)    # (num_targets, 4)
        target_boxes_flat = torch.cat(matched_boxes)       # (num_targets, 4)

        # A batch without any annotated object gives an undefined (NaN)
        # mean cross-entropy; skip it instead of corrupting the weights.
        if target_labels_flat.numel() == 0:
            continue

        classification_loss = classification_loss_fn(logits_flat, target_labels_flat)
        loss_bbox = bbox_loss_sum_fn(pred_boxes_flat, target_boxes_flat)

        # Compute total loss
        loss = classification_loss + loss_bbox

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_steps += 1

        if batch_idx % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(data_loader)}], Loss: {loss.item()}")

    # Advance the learning-rate schedule once per epoch
    scheduler.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Total Loss: {total_loss / max(num_steps, 1)}")

print("Training complete!")