In [1]:
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import vgg19, VGG19_Weights
from PIL import Image
from tqdm.notebook import tqdm
import os
import random
from skimage.exposure import match_histograms
import matplotlib.pyplot as plt

# Create directories to save outputs
os.makedirs('outputs', exist_ok=True)
os.makedirs('models', exist_ok=True)

# --- Configuration ---
CONFIG = {
    "VIDEO_PATH": "video.mp4",
    "IMG_SIZE": 512,
    "BATCH_SIZE": 2,
    "EPOCHS": 50,
    "LEARNING_RATE": 1e-4,
    "VALIDATION_SPLIT": 0.2,
    "NUM_CONTEXT_FRAMES": 4,
    "FRAME_SKIP": 3,
    "MASK_SHAPES": 5,
    "MAX_MASK_SIZE": 150,
    "L1_LOSS_WEIGHT": 1.0,
    "PERCEPTUAL_LOSS_WEIGHT": 0.05,
    "INITIAL_KEYFRAME_SECONDS": 5,  # New parameter for initial keyframe time window
}

# Set device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

def find_crop_box(video_path, threshold=10):
    """Finds the bounding box of the non-black content in a video frame."""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError("Cannot open video file")
    ret, frame = cap.read()
    cap.release()
    if not ret:
        raise IOError("Cannot read frame from video")
        
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    rows, cols = np.where(gray > threshold)
    if len(rows) == 0 or len(cols) == 0:
        return 0, frame.shape[0], 0, frame.shape[1]
    
    return np.min(rows), np.max(rows), np.min(cols), np.max(cols)

def generate_random_mask(height, width, num_shapes, max_size):
    """Generates a random mask with various shapes."""
    mask = np.zeros((height, width, 1), dtype=np.uint8)
    for _ in range(num_shapes):
        shape_type = random.randint(0, 1)
        x1 = random.randint(0, width - 1)
        y1 = random.randint(0, height - 1)
        size_x = random.randint(20, max_size)
        size_y = random.randint(20, max_size)
        
        if shape_type == 0:
            x2 = min(x1 + size_x, width)
            y2 = min(y1 + size_y, height)
            cv2.rectangle(mask, (x1, y1), (x2, y2), (255), -1)
        else:
            radius = random.randint(10, max_size // 2)
            cv2.circle(mask, (x1, y1), radius, (255), -1)
            
    return mask

def calculate_quality_score(frame_tensor):
    """Calculates a quality score for a frame based on sharpness."""
    gray_frame = frame_tensor[0] * 0.299 + frame_tensor[1] * 0.587 + frame_tensor[2] * 0.114
    laplacian_kernel = torch.tensor([[[[0, 1, 0], [1, -4, 1], [0, 1, 0]]]], dtype=torch.float32).to(DEVICE)
    laplacian = nn.functional.conv2d(gray_frame.unsqueeze(0).unsqueeze(0), laplacian_kernel, padding=1)
    score = laplacian.var()
    return score.item()

class VideoFrameDataset(Dataset):
    def __init__(self, frames, keyframe, crop_box, config):
        self.frames = frames
        self.keyframe = keyframe
        self.crop_box = crop_box
        self.config = config
        self.img_size = config["IMG_SIZE"]
        self.num_context = config["NUM_CONTEXT_FRAMES"]
        self.frame_skip = config["FRAME_SKIP"]
        
        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.frames) - (self.num_context * self.frame_skip)

    def __getitem__(self, idx):
        target_idx = idx + (self.num_context * self.frame_skip)
        
        target_frame_pil = Image.fromarray(cv2.cvtColor(self.frames[target_idx], cv2.COLOR_BGR2RGB))
        target_tensor = self.transform(target_frame_pil)
        
        context_tensors = []
        for i in range(self.num_context):
            context_idx = target_idx - ((i + 1) * self.frame_skip)
            frame_pil = Image.fromarray(cv2.cvtColor(self.frames[context_idx], cv2.COLOR_BGR2RGB))
            context_tensors.append(self.transform(frame_pil))
        
        keyframe_pil = Image.fromarray(cv2.cvtColor(self.keyframe, cv2.COLOR_BGR2RGB))
        target_pil = Image.fromarray(cv2.cvtColor(self.frames[target_idx], cv2.COLOR_BGR2RGB))
        keyframe_np = np.array(keyframe_pil)
        target_np = np.array(target_pil)
        matched_keyframe = match_histograms(keyframe_np, target_np, channel_axis=-1)
        keyframe_pil = Image.fromarray(matched_keyframe)
        keyframe_tensor = self.transform(keyframe_pil) * 0.5
        
        mask_np = generate_random_mask(self.img_size, self.img_size, self.config["MASK_SHAPES"], self.config["MAX_MASK_SIZE"])
        mask_tensor = transforms.functional.to_tensor(mask_np)

        masked_target_tensor = target_tensor * (1 - mask_tensor)
        
        input_stack = context_tensors[::-1] + [masked_target_tensor, keyframe_tensor]
        input_tensor = torch.cat(input_stack, dim=0)

        return input_tensor, mask_tensor, target_tensor

    def update_keyframe(self, new_keyframe):
        self.keyframe = new_keyframe

class InpaintingLoss(nn.Module):
    def __init__(self, perceptual_weight=0.1, l1_weight=1.0):
        super().__init__()
        self.l1_loss = nn.L1Loss()
        self.perceptual_loss = PerceptualLoss()
        self.perceptual_weight = perceptual_weight
        self.l1_weight = l1_weight

    def forward(self, generated_img, ground_truth_img, masks):
        masked_loss = self.l1_loss(generated_img * masks, ground_truth_img * masks)
        unmasked_loss = self.l1_loss(generated_img * (1 - masks), ground_truth_img * (1 - masks))
        perceptual = self.perceptual_loss(generated_img, ground_truth_img)
        
        total_loss = (self.l1_weight * masked_loss) + (self.perceptual_weight * perceptual) + (0.1 * unmasked_loss)
        return total_loss

class PerceptualLoss(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = vgg19(weights=VGG19_Weights.DEFAULT).features.to(DEVICE).eval()
        self.features = nn.Sequential(*list(vgg.children())[:35])
        for param in self.features.parameters():
            param.requires_grad = False
        self.loss = nn.L1Loss()

    def forward(self, generated, ground_truth):
        gen_features = self.features(generated)
        gt_features = self.features(ground_truth)
        return self.loss(gen_features, gt_features)

class FFC(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, ratio_gin, ratio_gout, stride=1, padding=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.ratio_gin = ratio_gin
        self.ratio_gout = ratio_gout
        
        self.in_local = int(self.in_channels * (1 - ratio_gin))
        self.in_global = self.in_channels - self.in_local
        self.out_local = int(self.out_channels * (1 - ratio_gout))
        self.out_global = self.out_channels - self.out_local

        if self.in_local > 0 and self.out_local > 0:
            self.conv_l = nn.Conv2d(self.in_local, self.out_local, kernel_size, stride, padding, bias=False)
        else:
            self.conv_l = None
        
        if self.in_global > 0 and self.out_global > 0:
            self.conv_g = nn.Conv2d(self.in_global * 2, self.out_global * 2, kernel_size=1, bias=False)
        else:
            self.conv_g = None

    def forward(self, x):
        batch_size, _, height, width = x.shape
        
        if self.in_global == 0 and self.conv_l is not None:
            return self.conv_l(x)
        
        if self.in_local == 0 and self.conv_g is not None:
            fft_g = torch.fft.rfft2(x, norm='ortho')
            fft_g_real_imag = torch.cat([fft_g.real, fft_g.imag], dim=1)
            ffc_out_real_imag = self.conv_g(fft_g_real_imag)
            ffc_out_real, ffc_out_imag = torch.split(ffc_out_real_imag, self.out_global, dim=1)
            fft_g = torch.complex(ffc_out_real, ffc_out_imag)
            return torch.fft.irfft2(fft_g, s=(height, width), norm='ortho')
        
        x_l, x_g = torch.split(x, [self.in_local, self.in_global], dim=1)

        out_l = self.conv_l(x_l) if self.conv_l is not None else None

        out_g = None
        if self.conv_g is not None:
            fft_g = torch.fft.rfft2(x_g, norm='ortho')
            fft_g_real_imag = torch.cat([fft_g.real, fft_g.imag], dim=1)
            ffc_out_real_imag = self.conv_g(fft_g_real_imag)
            ffc_out_real, ffc_out_imag = torch.split(ffc_out_real_imag, self.out_global, dim=1)
            fft_g = torch.complex(ffc_out_real, ffc_out_imag)
            out_g = torch.fft.irfft2(fft_g, s=(height, width), norm='ortho')

        if out_l is not None and out_g is not None:
            return torch.cat([out_l, out_g], dim=1)
        elif out_l is not None:
            return out_l
        elif out_g is not None:
            return out_g
        else:
            return torch.zeros(batch_size, self.out_channels, height, width, device=x.device)

class FFC_BN_ACT(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 ratio_gin=0.75, ratio_gout=0.75, stride=1, padding=0):
        super().__init__()
        self.ffc = FFC(in_channels, out_channels, kernel_size, ratio_gin, ratio_gout, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.ffc(x)
        x = self.bn(x)
        x = self.act(x)
        return x

class LaMaUNet(nn.Module):
    def __init__(self, n_channels=18, n_classes=3, bilinear=True):
        super().__init__()
        ratio = 0.75

        self.inc = nn.Sequential(
            FFC_BN_ACT(n_channels, 64, kernel_size=3, padding=1, ratio_gin=0, ratio_gout=0),
            FFC_BN_ACT(64, 64, kernel_size=3, padding=1, ratio_gin=ratio, ratio_gout=ratio)
        )
        self.down1 = nn.Sequential(
            nn.MaxPool2d(2),
            FFC_BN_ACT(64, 128, kernel_size=3, padding=1, ratio_gin=ratio, ratio_gout=ratio),
            FFC_BN_ACT(128, 128, kernel_size=3, padding=1, ratio_gin=ratio, ratio_gout=ratio)
        )
        self.down2 = nn.Sequential(
            nn.MaxPool2d(2),
            FFC_BN_ACT(128, 256, kernel_size=3, padding=1, ratio_gin=ratio, ratio_gout=ratio),
            FFC_BN_ACT(256, 256, kernel_size=3, padding=1, ratio_gin=ratio, ratio_gout=ratio)
        )
        self.bottleneck = nn.Sequential(
            nn.MaxPool2d(2),
            FFC_BN_ACT(256, 512, kernel_size=3, padding=1, ratio_gin=ratio, ratio_gout=ratio),
            FFC_BN_ACT(512, 512, kernel_size=3, padding=1, ratio_gin=ratio, ratio_gout=ratio),
            FFC_BN_ACT(512, 256, kernel_size=3, padding=1, ratio_gin=ratio, ratio_gout=ratio)
        )

        self.up1 = nn.Upsample(scale_factor=2, mode='bilinear' if bilinear else 'nearest')
        self.conv1 = nn.Sequential(
            FFC_BN_ACT(512, 256, kernel_size=3, padding=1, ratio_gin=ratio, ratio_gout=ratio),
            FFC_BN_ACT(256, 128, kernel_size=3, padding=1, ratio_gin=ratio, ratio_gout=ratio),
        )

        self.up2 = nn.Upsample(scale_factor=2, mode='bilinear' if bilinear else 'nearest')
        self.conv2 = nn.Sequential(
            FFC_BN_ACT(256, 128, kernel_size=3, padding=1, ratio_gin=ratio, ratio_gout=ratio),
            FFC_BN_ACT(128, 64, kernel_size=3, padding=1, ratio_gin=ratio, ratio_gout=ratio),
        )

        self.up3 = nn.Upsample(scale_factor=2, mode='bilinear' if bilinear else 'nearest')
        self.conv3 = nn.Sequential(
            FFC_BN_ACT(128, 64, kernel_size=3, padding=1, ratio_gin=ratio, ratio_gout=ratio),
            nn.Conv2d(64, n_classes, kernel_size=3, padding=1)
        )

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        
        x_bottle = self.bottleneck(x3)

        u1 = self.up1(x_bottle)
        c1 = torch.cat([u1, x3], dim=1)
        d1 = self.conv1(c1)

        u2 = self.up2(d1)
        c2 = torch.cat([u2, x2], dim=1)
        d2 = self.conv2(c2)

        u3 = self.up3(d2)
        c3 = torch.cat([u3, x1], dim=1)
        d3 = self.conv3(c3)

        return torch.sigmoid(d3)

print("Loading and preprocessing video frames...")
y_min, y_max, x_min, x_max = find_crop_box(CONFIG["VIDEO_PATH"])
cap = cv2.VideoCapture(CONFIG["VIDEO_PATH"])
all_frames = []
while True:
    ret, frame = cap.read()
    if not ret:
        break
    cropped_frame = frame[y_min:y_max, x_min:x_max]
    resized_frame = cv2.resize(cropped_frame, (CONFIG["IMG_SIZE"], CONFIG["IMG_SIZE"]), interpolation=cv2.INTER_AREA)
    all_frames.append(resized_frame)
cap.release()
print(f"Loaded {len(all_frames)} frames.")

# Get frame rate to estimate initial keyframe range
fps = cap.get(cv2.CAP_PROP_FPS) if cap.isOpened() else 30.0  # Default to 30 FPS if unavailable
initial_frame_count = int(CONFIG["INITIAL_KEYFRAME_SECONDS"] * fps)
initial_frames = all_frames[:min(initial_frame_count, len(all_frames))]

split_idx = int(len(all_frames) * (1 - CONFIG["VALIDATION_SPLIT"]))
train_frames = all_frames[:split_idx]
val_frames = all_frames[split_idx:]

print("Finding initial best keyframe from first few seconds...")
best_keyframe = None
best_score = -1
for frame_np in initial_frames:
    frame_tensor = transforms.ToTensor()(Image.fromarray(cv2.cvtColor(frame_np, cv2.COLOR_BGR2RGB))).to(DEVICE)
    score = calculate_quality_score(frame_tensor)
    if score > best_score:
        best_score = score
        best_keyframe = frame_np
print(f"Initial keyframe found with sharpness score: {best_score:.2f} from first {CONFIG['INITIAL_KEYFRAME_SECONDS']} seconds")

train_dataset = VideoFrameDataset(train_frames, best_keyframe, (y_min, y_max, x_min, x_max), CONFIG)
val_dataset = VideoFrameDataset(val_frames, best_keyframe, (y_min, y_max, x_min, x_max), CONFIG)

train_loader = DataLoader(train_dataset, batch_size=CONFIG["BATCH_SIZE"], shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=CONFIG["BATCH_SIZE"], shuffle=False, num_workers=2)

model = LaMaUNet(n_channels=18, n_classes=3).to(DEVICE)
criterion = InpaintingLoss(perceptual_weight=CONFIG["PERCEPTUAL_LOSS_WEIGHT"], l1_weight=CONFIG["L1_LOSS_WEIGHT"]).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=CONFIG["LEARNING_RATE"])

for epoch in range(CONFIG["EPOCHS"]):
    model.train()
    train_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['EPOCHS']} [T]")
    for inputs, masks, truths in progress_bar:
        inputs, masks, truths = inputs.to(DEVICE), masks.to(DEVICE), truths.to(DEVICE)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        
        loss = criterion(outputs, truths, masks)
        
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())
        
    avg_train_loss = train_loss / len(train_loader)

    model.eval()
    val_loss = 0.0
    new_best_keyframe_found = False
    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{CONFIG['EPOCHS']} [V]")
        for i, (inputs, masks, truths) in enumerate(progress_bar):
            inputs, masks, truths = inputs.to(DEVICE), masks.to(DEVICE), truths.to(DEVICE)
            outputs = model(inputs)
            loss = criterion(outputs, truths, masks)
            val_loss += loss.item()

            # Check for new best keyframe from the first sample of the batch
            if i == 0:
                for j in range(truths.size(0)):
                    score = calculate_quality_score(truths[j])
                    if score > best_score:
                        best_score = score
                        best_keyframe_tensor = truths[j].cpu()
                        best_keyframe = (best_keyframe_tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
                        best_keyframe = cv2.cvtColor(best_keyframe, cv2.COLOR_RGB2BGR)
                        new_best_keyframe_found = True

            # Save fixed sample from the first batch, first sample
            if i == 0:
                j = 0  # Explicitly use the first sample
                truth_img = (truths[j].cpu() * 255).clamp(0, 255).byte()
                masked_input = (truths[j].cpu() * (1 - masks[j].cpu()) * 255).clamp(0, 255).byte()
                output_img = (outputs[j].cpu() * 255).clamp(0, 255).byte()
                keyframe_img = (transforms.ToTensor()(Image.fromarray(cv2.cvtColor(best_keyframe, cv2.COLOR_BGR2RGB))).cpu() * 255).clamp(0, 255).byte()

                # Create 2x2 grid
                grid = np.zeros((CONFIG["IMG_SIZE"] * 2, CONFIG["IMG_SIZE"] * 2, 3), dtype=np.uint8)
                grid[:CONFIG["IMG_SIZE"], :CONFIG["IMG_SIZE"]] = np.array(Image.fromarray(truth_img.permute(1, 2, 0).numpy()))
                grid[:CONFIG["IMG_SIZE"], CONFIG["IMG_SIZE"]:] = np.array(Image.fromarray(masked_input.permute(1, 2, 0).numpy()))
                grid[CONFIG["IMG_SIZE"]:, :CONFIG["IMG_SIZE"]] = np.array(Image.fromarray(output_img.permute(1, 2, 0).numpy()))
                grid[CONFIG["IMG_SIZE"]:, CONFIG["IMG_SIZE"]:] = np.array(Image.fromarray(keyframe_img.permute(1, 2, 0).numpy()))

                # Save the grid with confirmation
                output_dir = f"outputs/epoch_{epoch+1}"
                os.makedirs(output_dir, exist_ok=True)
                output_path = f"{output_dir}/fixed_sample.png"
                grid_pil = Image.fromarray(grid)
                grid_pil.save(output_path)
                print(f"Saved fixed sample to {output_path}")

    avg_val_loss = val_loss / len(val_loader)
    
    if new_best_keyframe_found:
        print(f"New keyframe found with score: {best_score:.2f}. Updating datasets.")
        train_dataset.update_keyframe(best_keyframe)
        val_dataset.update_keyframe(best_keyframe)

    print(f"Epoch {epoch+1}/{CONFIG['EPOCHS']} -> Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    torch.save(model.state_dict(), f"models/inpainting_model_epoch_{epoch+1}.pth")

print("Training finished.")

Using device: cuda
Loading and preprocessing video frames...
Loaded 1548 frames.
Finding initial best keyframe from first few seconds...
Initial keyframe found with sharpness score: 0.00 from first 5 seconds


Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:01<00:00, 298MB/s]  


Epoch 1/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 1/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_1/fixed_sample.png
New keyframe found with score: 0.00. Updating datasets.
Epoch 1/50 -> Train Loss: 0.0303, Val Loss: 0.0227


Epoch 2/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 2/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_2/fixed_sample.png
Epoch 2/50 -> Train Loss: 0.0156, Val Loss: 0.0189


Epoch 3/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 3/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_3/fixed_sample.png
Epoch 3/50 -> Train Loss: 0.0136, Val Loss: 0.0172


Epoch 4/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 4/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_4/fixed_sample.png
Epoch 4/50 -> Train Loss: 0.0125, Val Loss: 0.0184


Epoch 5/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 5/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_5/fixed_sample.png
Epoch 5/50 -> Train Loss: 0.0119, Val Loss: 0.0161


Epoch 6/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 6/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_6/fixed_sample.png
Epoch 6/50 -> Train Loss: 0.0114, Val Loss: 0.0149


Epoch 7/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 7/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_7/fixed_sample.png
Epoch 7/50 -> Train Loss: 0.0110, Val Loss: 0.0140


Epoch 8/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 8/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_8/fixed_sample.png
Epoch 8/50 -> Train Loss: 0.0102, Val Loss: 0.0137


Epoch 9/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 9/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_9/fixed_sample.png
Epoch 9/50 -> Train Loss: 0.0095, Val Loss: 0.0135


Epoch 10/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 10/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_10/fixed_sample.png
Epoch 10/50 -> Train Loss: 0.0086, Val Loss: 0.0133


Epoch 11/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 11/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_11/fixed_sample.png
Epoch 11/50 -> Train Loss: 0.0083, Val Loss: 0.0125


Epoch 12/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 12/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_12/fixed_sample.png
Epoch 12/50 -> Train Loss: 0.0079, Val Loss: 0.0126


Epoch 13/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 13/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_13/fixed_sample.png
Epoch 13/50 -> Train Loss: 0.0075, Val Loss: 0.0112


Epoch 14/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 14/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_14/fixed_sample.png
Epoch 14/50 -> Train Loss: 0.0076, Val Loss: 0.0112


Epoch 15/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 15/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_15/fixed_sample.png
Epoch 15/50 -> Train Loss: 0.0073, Val Loss: 0.0111


Epoch 16/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 16/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_16/fixed_sample.png
Epoch 16/50 -> Train Loss: 0.0071, Val Loss: 0.0104


Epoch 17/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 17/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_17/fixed_sample.png
Epoch 17/50 -> Train Loss: 0.0070, Val Loss: 0.0100


Epoch 18/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 18/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_18/fixed_sample.png
Epoch 18/50 -> Train Loss: 0.0068, Val Loss: 0.0101


Epoch 19/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 19/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_19/fixed_sample.png
Epoch 19/50 -> Train Loss: 0.0066, Val Loss: 0.0103


Epoch 20/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 20/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_20/fixed_sample.png
Epoch 20/50 -> Train Loss: 0.0066, Val Loss: 0.0098


Epoch 21/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 21/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_21/fixed_sample.png
Epoch 21/50 -> Train Loss: 0.0065, Val Loss: 0.0096


Epoch 22/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 22/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_22/fixed_sample.png
Epoch 22/50 -> Train Loss: 0.0063, Val Loss: 0.0093


Epoch 23/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 23/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_23/fixed_sample.png
Epoch 23/50 -> Train Loss: 0.0062, Val Loss: 0.0095


Epoch 24/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 24/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_24/fixed_sample.png
Epoch 24/50 -> Train Loss: 0.0063, Val Loss: 0.0091


Epoch 25/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 25/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_25/fixed_sample.png
Epoch 25/50 -> Train Loss: 0.0060, Val Loss: 0.0094


Epoch 26/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 26/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_26/fixed_sample.png
Epoch 26/50 -> Train Loss: 0.0061, Val Loss: 0.0093


Epoch 27/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 27/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_27/fixed_sample.png
Epoch 27/50 -> Train Loss: 0.0059, Val Loss: 0.0091


Epoch 28/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 28/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_28/fixed_sample.png
Epoch 28/50 -> Train Loss: 0.0059, Val Loss: 0.0087


Epoch 29/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 29/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_29/fixed_sample.png
Epoch 29/50 -> Train Loss: 0.0058, Val Loss: 0.0087


Epoch 30/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 30/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_30/fixed_sample.png
Epoch 30/50 -> Train Loss: 0.0058, Val Loss: 0.0087


Epoch 31/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 31/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_31/fixed_sample.png
Epoch 31/50 -> Train Loss: 0.0058, Val Loss: 0.0087


Epoch 32/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 32/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_32/fixed_sample.png
Epoch 32/50 -> Train Loss: 0.0058, Val Loss: 0.0083


Epoch 33/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 33/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_33/fixed_sample.png
Epoch 33/50 -> Train Loss: 0.0057, Val Loss: 0.0084


Epoch 34/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 34/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_34/fixed_sample.png
Epoch 34/50 -> Train Loss: 0.0057, Val Loss: 0.0086


Epoch 35/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 35/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_35/fixed_sample.png
Epoch 35/50 -> Train Loss: 0.0057, Val Loss: 0.0082


Epoch 36/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 36/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_36/fixed_sample.png
Epoch 36/50 -> Train Loss: 0.0055, Val Loss: 0.0084


Epoch 37/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 37/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_37/fixed_sample.png
Epoch 37/50 -> Train Loss: 0.0055, Val Loss: 0.0082


Epoch 38/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 38/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_38/fixed_sample.png
Epoch 38/50 -> Train Loss: 0.0055, Val Loss: 0.0084


Epoch 39/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 39/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_39/fixed_sample.png
Epoch 39/50 -> Train Loss: 0.0054, Val Loss: 0.0084


Epoch 40/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 40/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_40/fixed_sample.png
Epoch 40/50 -> Train Loss: 0.0053, Val Loss: 0.0083


Epoch 41/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 41/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_41/fixed_sample.png
Epoch 41/50 -> Train Loss: 0.0055, Val Loss: 0.0082


Epoch 42/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 42/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_42/fixed_sample.png
Epoch 42/50 -> Train Loss: 0.0053, Val Loss: 0.0078


Epoch 43/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 43/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_43/fixed_sample.png
Epoch 43/50 -> Train Loss: 0.0053, Val Loss: 0.0079


Epoch 44/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 44/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_44/fixed_sample.png
Epoch 44/50 -> Train Loss: 0.0053, Val Loss: 0.0080


Epoch 45/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 45/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_45/fixed_sample.png
Epoch 45/50 -> Train Loss: 0.0054, Val Loss: 0.0083


Epoch 46/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 46/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_46/fixed_sample.png
Epoch 46/50 -> Train Loss: 0.0053, Val Loss: 0.0079


Epoch 47/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 47/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_47/fixed_sample.png
Epoch 47/50 -> Train Loss: 0.0052, Val Loss: 0.0079


Epoch 48/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 48/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_48/fixed_sample.png
Epoch 48/50 -> Train Loss: 0.0052, Val Loss: 0.0077


Epoch 49/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 49/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_49/fixed_sample.png
Epoch 49/50 -> Train Loss: 0.0051, Val Loss: 0.0076


Epoch 50/50 [T]:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch 50/50 [V]:   0%|          | 0/149 [00:00<?, ?it/s]

Saved fixed sample to outputs/epoch_50/fixed_sample.png
Epoch 50/50 -> Train Loss: 0.0052, Val Loss: 0.0077
Training finished.
