### Dataloader

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import numpy as np
import json
from torchvision import transforms
import cv2


class SegmentationDataset(Dataset):
    """Pair JSON polygon annotations with their ``.jpg`` images for segmentation.

    Each ``<name>.json`` file in ``json_dir`` is expected to carry
    ``imageHeight``, ``imageWidth`` and a ``shapes`` list whose entries have
    ``points`` (LabelMe-style, presumably -- confirm against the annotation
    tool). The matching image is ``<name>.jpg`` in ``image_dir``.

    Items are ``(image, mask)``:
        image: float tensor (3, 256, 256) from ``transforms.ToTensor``.
        mask:  float tensor (256, 256) holding 0.0 / 1.0.
    """

    def __init__(self, image_dir, json_dir):
        self.image_dir = image_dir
        self.json_dir = json_dir
        # Enumerate annotations from json_dir -- the directory __getitem__ reads
        # them from -- sorted so that indices are stable across runs.
        self.json_data = self._list_json_files(json_dir)
        # Built once here rather than on every __getitem__ call.
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _list_json_files(directory):
        """Return the sorted names of the ``.json`` files directly in *directory*."""
        return sorted(f for f in os.listdir(directory) if f.endswith('.json'))

    def __len__(self):
        return len(self.json_data)

    def load_mask_from_json(self, json_path):
        """Rasterize every shape in *json_path* into a uint8 mask of 0s and 1s.

        The mask has shape (imageHeight, imageWidth). Each shape's ``points``
        are cast to int32 and filled with value 1; labels are ignored, so all
        shapes count as the same foreground class.
        NOTE(review): every shape is filled as a polygon; a 2-point rectangle
        annotation would rasterize as a line -- confirm annotations are polygons.
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        mask = np.zeros((data['imageHeight'], data['imageWidth']), dtype=np.uint8)
        for shape in data['shapes']:
            points = np.array(shape['points'], dtype=np.int32)
            mask = cv2.fillPoly(mask, [points], 1)
        return mask

    def __getitem__(self, idx):
        data_name = self.json_data[idx]
        mask = self.load_mask_from_json(os.path.join(self.json_dir, data_name))

        # Swap only the extension (str.replace would rewrite every '.json' in the name).
        stem, _ = os.path.splitext(data_name)
        img_path = os.path.join(self.image_dir, stem + '.jpg')
        # The context manager closes the file handle PIL keeps open for lazy loading.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        image = self.transform(image)

        # Nearest-neighbour is the conventional resize for label maps: it copies
        # labels instead of blending neighbouring pixels.
        mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)
        mask = torch.from_numpy(mask).float()

        return image, mask

### Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class UNetWithBN(nn.Module):
    """Small convolutional encoder-decoder for single-channel (binary) segmentation.

    Despite the name, this network has no U-Net skip connections: the decoder
    only sees the bottleneck features. The layer layout is deliberately left
    unchanged so existing ``state_dict`` checkpoints keep loading.

    Input:  float tensor of shape (N, 3, H, W).
    Output: raw logits of shape (N, 1, H, W); no sigmoid is applied here.
    """

    def __init__(self):
        super().__init__()

        # Encoder: 3 -> 12 -> 24 -> 34 channels; each stage is a 3x3 "same"
        # convolution followed by BatchNorm and ReLU.
        self.enc1 = nn.Sequential(
            nn.Conv2d(3, 12, 3, padding=1),
            nn.BatchNorm2d(12),
            nn.ReLU(inplace=True),
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(12, 24, 3, padding=1),
            nn.BatchNorm2d(24),
            nn.ReLU(inplace=True),
        )
        self.enc3 = nn.Sequential(
            nn.Conv2d(24, 34, 3, padding=1),
            nn.BatchNorm2d(34),
            nn.ReLU(inplace=True),
        )

        # Decoder: 34 -> 24 -> 12 channels, same stage layout as the encoder.
        self.dec3 = nn.Sequential(
            nn.Conv2d(34, 24, 3, padding=1),
            nn.BatchNorm2d(24),
            nn.ReLU(inplace=True),
        )
        self.dec2 = nn.Sequential(
            nn.Conv2d(24, 12, 3, padding=1),
            nn.BatchNorm2d(12),
            nn.ReLU(inplace=True),
        )

        # 1x1 convolution down to a single logit channel.
        self.final = nn.Conv2d(12, 1, 1)

        self.pool = nn.MaxPool2d(2)
        # Parameter-free (not part of the state_dict). Kept so ``model.up`` still
        # exists; forward() uses _upsample_like, which gives identical results
        # whenever the spatial size is divisible by 4.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    @staticmethod
    def _upsample_like(x, ref):
        """Bilinearly resize *x* to the spatial size of *ref*.

        Max-pooling floors odd sizes, so a fixed x2 upsample would return a
        smaller map than the input; resizing to the matching encoder feature
        size makes the output always match the input's H and W.
        """
        return F.interpolate(x, size=ref.shape[-2:], mode='bilinear', align_corners=True)

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)               # (N, 12, H, W)
        e2 = self.enc2(self.pool(e1))   # (N, 24, H // 2, W // 2)
        e3 = self.enc3(self.pool(e2))   # (N, 34, H // 4, W // 4)

        # Decoder (no skip connections -- see class docstring)
        d3 = self.dec3(self._upsample_like(e3, e2))
        d2 = self.dec2(self._upsample_like(d3, e1))

        return self.final(d2)


### Training

In [None]:
# Training entry point: fits UNetWithBN and reports the mean training loss per epoch.
def train_model(image_dir, json_dir, num_epochs=10, batch_size=4, lr=0.001,
                best_model_path=None):
    """Train a ``UNetWithBN`` on the annotated images and return it.

    Args:
        image_dir: directory holding the ``.jpg`` images.
        json_dir: directory holding the JSON annotations.
        num_epochs: number of passes over the dataset.
        batch_size: mini-batch size for the DataLoader.
        lr: initial Adam learning rate.
        best_model_path: if given, the model's ``state_dict`` is saved there
            every time the epoch-average loss reaches a new minimum.

    Returns:
        The trained model, left on the training device and in train mode.

    Raises:
        ValueError: if the dataset contains no samples.
    """
    dataset = SegmentationDataset(image_dir, json_dir)
    if len(dataset) == 0:
        # Fail early with a clear message instead of a ZeroDivisionError at the
        # epoch-average computation below.
        raise ValueError(
            f'No training samples found (image_dir={image_dir!r}, json_dir={json_dir!r})')
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UNetWithBN().to(device)
    # The model emits raw logits; this loss applies the sigmoid internally.
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Lower the learning rate when the epoch loss plateaus (patience=3 epochs).
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)

    best_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for i, (images, masks) in enumerate(dataloader):
            images = images.to(device)
            # (N, H, W) -> (N, 1, H, W) to match the model's single output channel.
            masks = masks.to(device).unsqueeze(1)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Print batch progress
            if i % 10 == 0:
                print(f'Epoch {epoch+1}/{num_epochs}, Batch {i}/{len(dataloader)}, '
                      f'Loss: {loss.item():.4f}')

        epoch_loss = running_loss / len(dataloader)
        print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {epoch_loss:.4f}')

        scheduler.step(epoch_loss)

        # Checkpoint the lowest training loss seen so far (opt-in).
        if best_model_path is not None and epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(model.state_dict(), best_model_path)

    return model

### Inference

In [None]:
def predict(model, image_path, threshold=0.5):
    """Segment a single image file and return a boolean mask.

    Args:
        model: segmentation network returning one logit channel (e.g. ``UNetWithBN``).
        image_path: path of the image to segment.
        threshold: probability above which a pixel is marked True.

    Returns:
        numpy bool array of the squeezed prediction. For a single-channel model
        this is (256, 256), because the input is resized to 256x256; it is not
        at the original image resolution.

    Note:
        Switches *model* to eval mode and leaves it there.
    """
    # Run on whatever device the model's weights already live on. Choosing
    # 'cuda' independently raised a device-mismatch error whenever a CPU-resident
    # model was passed on a machine with a GPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    model.eval()

    # The context manager closes the file handle PIL keeps open for lazy loading.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    image = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(image)
        prediction = torch.sigmoid(output) > threshold

    return prediction.cpu().squeeze().numpy()

In [None]:
%%time
# NOTE(review): `model` is not assigned in any visible cell; presumably it comes
# from `model = train_model(image_dir, json_dir)` run elsewhere -- confirm.
# Despite its name, `pred_dir` holds the path of a single image file (it is
# passed to `predict` and to `plt.imread`), not a directory.
pred_dir = 'sample image'
mask = predict(model, pred_dir)

In [None]:
import matplotlib.pyplot as plt
# Show the input image at its original resolution. `predict` resizes its input
# to 256x256, so the mask shown next is at a different size unless the image
# already was 256x256.
plt.imshow(plt.imread(pred_dir))

In [None]:
# Display the boolean mask returned by `predict` (True where the predicted
# probability exceeds 0.5).
plt.imshow(mask)

In [None]:
# Save the weights of `model` as they are now. If `model` came from
# `train_model`, these are the last epoch's weights, not necessarily the
# lowest-loss ones despite the file name. `process_video` reloads a state_dict
# saved this way via `load_state_dict(torch.load(...))`.
torch.save(model.state_dict(), 'best_model.pth')

### Inference on video

In [None]:
import cv2
import torch
import numpy as np
from torchvision import transforms
from PIL import Image

def process_video(model_path, video_path, output_path, *, alpha=0.5, threshold=0.5):
    """Write a copy of a video with the predicted segmentation overlaid in green.

    Args:
        model_path: path to a ``UNetWithBN`` ``state_dict`` saved with ``torch.save``.
        video_path: input video readable by ``cv2.VideoCapture``.
        output_path: destination file, written with the ``mp4v`` codec at the
            input's frame size and frame rate.
        alpha: weight of the green overlay added onto each frame.
        threshold: sigmoid probability above which a pixel counts as foreground.

    Raises:
        OSError: if the input video cannot be opened or the output video
            cannot be created.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UNetWithBN().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # Same preprocessing as `predict`: resize to 256x256, convert to a float tensor.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise OSError(f'Could not open input video: {video_path}')

    out = None
    try:
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep fractional rates such as 29.97 (int() truncated them, so the output
        # played back at the wrong speed). Some containers report 0 or NaN; the
        # `not fps > 0` form catches both and falls back to 30 fps.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30.0

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
        if not out.isOpened():
            raise OSError(f'Could not open output video for writing: {output_path}')

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes frames as BGR; the model is fed RGB images.
            frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            input_tensor = transform(frame_pil).unsqueeze(0).to(device)

            with torch.no_grad():
                output = model(input_tensor)
                pred_mask = (torch.sigmoid(output) > threshold).squeeze().cpu().numpy()

            # Scale the 256x256 mask back to frame size; nearest-neighbour keeps it 0/1.
            pred_mask = cv2.resize(pred_mask.astype(np.uint8), (frame_width, frame_height),
                                   interpolation=cv2.INTER_NEAREST)

            colored_mask = np.zeros_like(frame)
            colored_mask[pred_mask == 1] = [0, 255, 0]  # BGR order: pure green

            # Add the tint on top of the unchanged frame (uint8 result saturates at 255).
            output_frame = cv2.addWeighted(frame, 1, colored_mask, alpha, 0)
            out.write(output_frame)
    finally:
        # Release both handles even if inference fails mid-video.
        cap.release()
        if out is not None:
            out.release()

if __name__ == "__main__":
    # Placeholder paths -- edit before running. The original assignments had no
    # right-hand side, which is a SyntaxError that broke this whole cell.
    # 'best_model.pth' is the file the save cell above writes.
    model_path = 'best_model.pth'  # Path to your trained model
    video_path = 'input_video.mp4'  # Path to your input video
    output_path = 'output_video.mp4'  # Path for saving the output video

    process_video(model_path, video_path, output_path)
