In [1]:

# # Improved MoCo Pretraining and Segmentation for Kvasir-Instrument Dataset
#
# This notebook implements an improved version of MoCo pretraining followed by surgical instrument segmentation.
# Phase 1: MoCo Pretraining
# Phase 2: Segmentation Fine-tuning

# Install required packages
!pip install -q albumentations
!pip install -q tensorboard
!pip install -q segmentation-models-pytorch
!pip install -q monai

# %% [code]
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import models, transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy import ndimage
import segmentation_models_pytorch as smp
from monai.losses import DiceCELoss, DiceFocalLoss
from monai.networks.nets import UNet
import random
import warnings
from collections import deque
warnings.filterwarnings('ignore')

# %% [code]
# Set random seeds for reproducibility
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

seed_everything()

# Set up paths
DATA_DIR = '/kaggle/input/kvasir-dataset/kvasir-instrument'
RESULTS_DIR = '/kaggle/working/results'
os.makedirs(RESULTS_DIR, exist_ok=True)

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name()}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB')

# %% [markdown]
# ## 1. MoCo Implementation

# %% [code]
class MoCo(nn.Module):
    def __init__(self, dim=128, K=65536, m=0.999, T=0.07):
        """
        dim: feature dimension (default: 128)
        K: queue size (default: 65536)
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)
        """
        super().__init__()

        self.K = K
        self.m = m
        self.T = T

        # create the encoders
        self.encoder_q = models.resnet50(weights=None)
        self.encoder_k = models.resnet50(weights=None)

        # remove the final fc layer
        self.encoder_q.fc = nn.Sequential(
            nn.Linear(2048, 2048),
            nn.ReLU(),
            nn.Linear(2048, dim)
        )
        self.encoder_k.fc = nn.Sequential(
            nn.Linear(2048, 2048),
            nn.ReLU(),
            nn.Linear(2048, dim)
        )

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

        # create the queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """Momentum update of the key encoder"""
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        
        # replace the keys at ptr (dequeue and enqueue)
        if ptr + batch_size > self.K:
            self.queue[:, ptr:] = keys[:, :self.K-ptr].T
            self.queue[:, :batch_size-(self.K-ptr)] = keys[:, self.K-ptr:].T
            ptr = batch_size-(self.K-ptr)
        else:
            self.queue[:, ptr:ptr + batch_size] = keys.T
            ptr = (ptr + batch_size) % self.K

        self.queue_ptr[0] = ptr

    def forward(self, im_q, im_k):
        # compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder
            k = self.encoder_k(im_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators
        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(logits.device)

        # dequeue and enqueue
        self._dequeue_and_enqueue(k)

        return logits, labels

# %% [code]
class MoCoDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        
        # Load all image paths
        self.images = sorted([
            f for f in os.listdir(os.path.join(data_dir, 'images')) 
            if f.endswith(('.jpg', '.png'))
        ])
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.data_dir, 'images', img_name)
        
        # Read image
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        if self.transform:
            aug1 = self.transform(image=image)['image']
            aug2 = self.transform(image=image)['image']
            return aug1, aug2
        return image, image

def get_moco_augmentation():
    return A.Compose([
        A.RandomResizedCrop(224, 224),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8),
        A.GaussianBlur(blur_limit=(3, 7), p=0.5),
        A.GaussNoise(p=0.5),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

# %% [markdown]
# ## 2. MoCo Training

# %% [code]
def train_moco(data_dir, num_epochs=200):
    # Create dataset and dataloader
    dataset = MoCoDataset(data_dir, transform=get_moco_augmentation())
    dataloader = DataLoader(
        dataset, 
        batch_size=128,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        drop_last=True
    )
    
    # Create MoCo model
    model = MoCo().to(device)
    
    # Create optimizer
    optimizer = optim.Adam(model.encoder_q.parameters(), lr=1e-3)
    
    # Create scheduler
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    
    # Create criterion
    criterion = nn.CrossEntropyLoss()
    
    # Initialize tensorboard
    writer = SummaryWriter(log_dir=os.path.join(RESULTS_DIR, 'moco_logs'))
    
    print(f'Dataset size: {len(dataset)}')
    
    # Training loop
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        
        with tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}') as pbar:
            for im_q, im_k in pbar:
                im_q, im_k = im_q.to(device), im_k.to(device)
                
                # Forward pass
                output, target = model(im_q, im_k)
                loss = criterion(output, target)
                
                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                # Update progress bar
                total_loss += loss.item()
                pbar.set_postfix({'loss': total_loss / (pbar.n + 1)})
        
        # Step scheduler
        scheduler.step()
        
        # Log metrics
        avg_loss = total_loss / len(dataloader)
        writer.add_scalar('Loss/train', avg_loss, epoch)
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}')
        
        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, os.path.join(RESULTS_DIR, f'moco_checkpoint_epoch_{epoch+1}.pth'))
    
    # Save final model
    torch.save({
        'epoch': num_epochs-1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, os.path.join(RESULTS_DIR, 'final_moco_model.pth'))
    
    writer.close()
    print('MoCo pretraining completed!')
    return model

# %% [markdown]
# ## 1. Improved Dataset and Augmentations

# %% [code]
class ImprovedKvasirDataset(Dataset):
    def __init__(self, data_dir, split='train', transform=None):
        self.data_dir = data_dir
        self.split = split
        self.transform = transform
        
        # Load all image paths
        self.images = sorted([f for f in os.listdir(os.path.join(data_dir, 'images')) 
                            if f.endswith(('.jpg', '.png'))])
        
        # Split dataset
        if split == 'train':
            self.images = self.images[:int(0.8 * len(self.images))]
        elif split == 'val':
            self.images = self.images[int(0.8 * len(self.images)):int(0.9 * len(self.images))]
        else:  # test
            self.images = self.images[int(0.9 * len(self.images)):]
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.data_dir, 'images', img_name)
        mask_path = os.path.join(self.data_dir, 'masks', img_name)
        
        # Read image and mask
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        
        # Normalize mask to binary
        mask = (mask > 127).astype(np.float32)
        
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        
        return image, mask

# Advanced augmentations
def get_training_augmentation():
    return A.Compose([
        A.RandomResizedCrop(256, 256, scale=(0.8, 1.0)),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, p=0.5),
        A.OneOf([
            A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03, p=0.5),
            A.GridDistortion(p=0.5),
            A.OpticalDistortion(distort_limit=1, shift_limit=0.5, p=0.5),
        ], p=0.3),
        A.OneOf([
            A.GaussNoise(p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.RandomGamma(p=0.5),
        ], p=0.3),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

def get_validation_augmentation():
    return A.Compose([
        A.Resize(256, 256),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

# %% [markdown]
# ## 2. Improved Model Architecture

# %% [code]
class ImprovedSegmentationModel(nn.Module):
    def __init__(self, moco_path=None):
        super().__init__()
        # Load MoCo pretrained model if available
        if moco_path and os.path.exists(moco_path):
            print("Loading MoCo pretrained weights...")
            moco_state = torch.load(moco_path)
            moco_model = MoCo()
            moco_model.load_state_dict(moco_state['model_state_dict'])
            encoder = moco_model.encoder_q
        else:
            print("Using ImageNet pretrained weights...")
            encoder = models.resnet50(weights='IMAGENET1K_V2')
        
        # Remove the final fc layer
        encoder = nn.Sequential(*list(encoder.children())[:-2])
        
        self.model = smp.UnetPlusPlus(
            encoder_name='resnet50',
            encoder_weights=None,  # We'll load our pretrained weights
            in_channels=3,
            classes=1,
            activation=None
        )
        
        # Load pretrained encoder weights
        if moco_path and os.path.exists(moco_path):
            self.model.encoder = encoder
        
    def forward(self, x):
        return self.model(x)

class DiceBCELoss(nn.Module):
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth
        
    def forward(self, pred, target):
        pred = torch.sigmoid(pred)
        
        # Dice loss
        dice = 1 - (2 * (pred * target).sum() + self.smooth) / \
               (pred.sum() + target.sum() + self.smooth)
        
        # BCE loss
        bce = F.binary_cross_entropy_with_logits(pred, target, reduction='mean')
        
        # Combined loss
        return dice + bce

class BoundaryLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.laplacian_kernel = torch.tensor([
            [-1, -1, -1],
            [-1,  8, -1],
            [-1, -1, -1]
        ], dtype=torch.float32).view(1, 1, 3, 3).to(device)
        
    def forward(self, pred, target):
        pred = torch.sigmoid(pred)
        
        # Extract boundaries
        pred_boundary = F.conv2d(pred, self.laplacian_kernel, padding=1)
        target_boundary = F.conv2d(target, self.laplacian_kernel, padding=1)
        
        # Calculate boundary loss
        boundary_loss = F.mse_loss(pred_boundary, target_boundary)
        return boundary_loss

# %% [markdown]
# ## 3. Training Functions with Improvements

# %% [code]
def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, epoch):
    model.train()
    epoch_loss = 0
    
    with tqdm(train_loader, desc=f'Training Epoch {epoch + 1}') as pbar:
        for images, masks in pbar:
            images = images.to(device)
            masks = masks.to(device).unsqueeze(1)
            
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Update progress bar
            epoch_loss += loss.item()
            pbar.set_postfix({'loss': epoch_loss / (pbar.n + 1)})
    
    if scheduler is not None:
        scheduler.step(epoch_loss / len(train_loader))
    
    return epoch_loss / len(train_loader)

def validate(model, val_loader, criterion):
    model.eval()
    val_loss = 0
    dice_scores = []
    
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device).unsqueeze(1)
            
            outputs = model(images)
            loss = criterion(outputs, masks)
            val_loss += loss.item()
            
            # Calculate Dice score
            pred = torch.sigmoid(outputs) > 0.5
            dice = (2 * (pred * masks).sum()) / (pred.sum() + masks.sum() + 1e-8)
            dice_scores.append(dice.item())
    
    return val_loss / len(val_loader), np.mean(dice_scores)

def post_process_prediction(pred):
    # Convert to numpy
    pred = pred.cpu().numpy().squeeze()
    
    # Apply threshold
    binary = (pred > 0.5).astype(np.uint8)
    
    # Remove small objects
    binary = ndimage.binary_opening(binary, structure=np.ones((3,3)))
    
    # Remove small holes
    binary = ndimage.binary_closing(binary, structure=np.ones((3,3)))
    
    # Keep only the largest connected component
    labels, num_features = ndimage.label(binary)
    if num_features > 0:
        sizes = ndimage.sum(binary, labels, range(1, num_features + 1))
        max_label = sizes.argmax() + 1
        binary = (labels == max_label).astype(np.uint8)
    
    return torch.from_numpy(binary).float()

# %% [markdown]
# ## 4. Main Training Loop

# %% [code]
def main():
    # Create datasets
    train_dataset = ImprovedKvasirDataset(
        DATA_DIR,
        split='train',
        transform=get_training_augmentation()
    )
    
    val_dataset = ImprovedKvasirDataset(
        DATA_DIR,
        split='val',
        transform=get_validation_augmentation()
    )
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=16,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=16,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )
    
    # Create model
    model = ImprovedSegmentationModel().to(device)
    
    # Initialize loss functions
    dice_bce_loss = DiceBCELoss()
    boundary_loss = BoundaryLoss()
    
    def combined_loss(pred, target):
        return dice_bce_loss(pred, target) + 0.5 * boundary_loss(pred, target)
    
    # Create optimizer and scheduler
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=5,
        verbose=True
    )
    
    # Initialize tensorboard
    writer = SummaryWriter(log_dir=os.path.join(RESULTS_DIR, 'logs'))
    
    # Training loop
    num_epochs = 50
    best_dice = 0
    
    for epoch in range(num_epochs):
        # Train
        train_loss = train_one_epoch(
            model, train_loader, combined_loss,
            optimizer, scheduler, epoch
        )
        
        # Validate
        val_loss, dice_score = validate(model, val_loader, combined_loss)
        
        # Log metrics
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Loss/val', val_loss, epoch)
        writer.add_scalar('Dice/val', dice_score, epoch)
        
        print(f'Epoch {epoch + 1}/{num_epochs}')
        print(f'Train Loss: {train_loss:.4f}')
        print(f'Val Loss: {val_loss:.4f}')
        print(f'Dice Score: {dice_score:.4f}')
        
        # Save best model
        if dice_score > best_dice:
            best_dice = dice_score
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_dice': best_dice,
            }, os.path.join(RESULTS_DIR, 'best_model.pth'))
            print(f'Saved best model with Dice score: {best_dice:.4f}')
        
        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_dice': best_dice,
            }, os.path.join(RESULTS_DIR, f'checkpoint_epoch_{epoch+1}.pth'))
    
    writer.close()
    print('Training completed!')
    
    return model

# %% [markdown]
# ## 5. Visualization and Evaluation

# %% [code]
def visualize_predictions(model, dataset, num_samples=5):
    model.eval()
    test_loader = DataLoader(dataset, batch_size=1, shuffle=True)
    
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5*num_samples))
    
    with torch.no_grad():
        for i, (image, mask) in enumerate(test_loader):
            if i >= num_samples:
                break
                
            image = image.to(device)
            output = model(image)
            pred = torch.sigmoid(output)
            pred = post_process_prediction(pred)
            
            # Denormalize image
            image = image.cpu().squeeze().permute(1,2,0)
            image = image * torch.tensor([0.229, 0.224, 0.225]) + torch.tensor([0.485, 0.456, 0.406])
            image = image.numpy()
            image = np.clip(image, 0, 1)
            
            # Plot
            axes[i, 0].imshow(image)
            axes[i, 0].set_title('Original Image')
            axes[i, 0].axis('off')
            
            axes[i, 1].imshow(mask.squeeze(), cmap='gray')
            axes[i, 1].set_title('Ground Truth')
            axes[i, 1].axis('off')
            
            axes[i, 2].imshow(pred.squeeze(), cmap='gray')
            axes[i, 2].set_title('Prediction')
            axes[i, 2].axis('off')
    
    plt.tight_layout()
    plt.show()

# %% [markdown]
# ## 6. Run Training

# %% [code]
if __name__ == "__main__":
    # Phase 1: MoCo Pretraining
    print("=== Phase 1: MoCo Pretraining ===")
    moco_model = train_moco(DATA_DIR)
    
    # Phase 2: Segmentation Training
    print("\n=== Phase 2: Segmentation Training ===")
    model = main()
    
    # Create test dataset
    test_dataset = ImprovedKvasirDataset(
        DATA_DIR,
        split='test',
        transform=get_validation_augmentation()
    )
    
    # Visualize results
    visualize_predictions(model, test_dataset) 

  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m121.3/121.3 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for efficientnet-pytorch (setup.py) ... [?25l[?25hdone
  Building wheel for pretrainedmodels (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m25.9 MB/s[0m eta [36m0:00:00[0m
[?25h

  check_for_updates()


Using device: cuda
GPU: Tesla P100-PCIE-16GB
Memory: 15.9 GB
=== Phase 1: MoCo Pretraining ===
Dataset size: 590


Epoch 1/200: 100%|██████████| 4/4 [00:09<00:00,  2.44s/it, loss=4.18]


Epoch 1/200, Loss: 4.1805


Epoch 2/200: 100%|██████████| 4/4 [00:08<00:00,  2.00s/it, loss=6.34]


Epoch 2/200, Loss: 6.3402


Epoch 3/200: 100%|██████████| 4/4 [00:08<00:00,  2.07s/it, loss=6.62]


Epoch 3/200, Loss: 6.6168


Epoch 4/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=6.64]


Epoch 4/200, Loss: 6.6359


Epoch 5/200: 100%|██████████| 4/4 [00:07<00:00,  1.90s/it, loss=6.64]


Epoch 5/200, Loss: 6.6433


Epoch 6/200: 100%|██████████| 4/4 [00:07<00:00,  1.98s/it, loss=6.67]


Epoch 6/200, Loss: 6.6742


Epoch 7/200: 100%|██████████| 4/4 [00:08<00:00,  2.09s/it, loss=6.73]


Epoch 7/200, Loss: 6.7290


Epoch 8/200: 100%|██████████| 4/4 [00:07<00:00,  1.99s/it, loss=6.78]


Epoch 8/200, Loss: 6.7774


Epoch 9/200: 100%|██████████| 4/4 [00:07<00:00,  1.89s/it, loss=6.83]


Epoch 9/200, Loss: 6.8254


Epoch 10/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=6.87]


Epoch 10/200, Loss: 6.8684


Epoch 11/200: 100%|██████████| 4/4 [00:08<00:00,  2.06s/it, loss=6.9]


Epoch 11/200, Loss: 6.9001


Epoch 12/200: 100%|██████████| 4/4 [00:07<00:00,  1.98s/it, loss=6.93]


Epoch 12/200, Loss: 6.9308


Epoch 13/200: 100%|██████████| 4/4 [00:07<00:00,  1.97s/it, loss=6.96]


Epoch 13/200, Loss: 6.9569


Epoch 14/200: 100%|██████████| 4/4 [00:07<00:00,  1.97s/it, loss=6.99]


Epoch 14/200, Loss: 6.9921


Epoch 15/200: 100%|██████████| 4/4 [00:08<00:00,  2.06s/it, loss=7.02]


Epoch 15/200, Loss: 7.0161


Epoch 16/200: 100%|██████████| 4/4 [00:07<00:00,  1.95s/it, loss=7.05]


Epoch 16/200, Loss: 7.0469


Epoch 17/200: 100%|██████████| 4/4 [00:07<00:00,  1.92s/it, loss=7.08]


Epoch 17/200, Loss: 7.0751


Epoch 18/200: 100%|██████████| 4/4 [00:08<00:00,  2.01s/it, loss=7.1]


Epoch 18/200, Loss: 7.1010


Epoch 19/200: 100%|██████████| 4/4 [00:08<00:00,  2.07s/it, loss=7.12]


Epoch 19/200, Loss: 7.1240


Epoch 20/200: 100%|██████████| 4/4 [00:07<00:00,  1.92s/it, loss=7.15]


Epoch 20/200, Loss: 7.1479


Epoch 21/200: 100%|██████████| 4/4 [00:07<00:00,  1.97s/it, loss=7.18]


Epoch 21/200, Loss: 7.1822


Epoch 22/200: 100%|██████████| 4/4 [00:08<00:00,  2.01s/it, loss=7.21]


Epoch 22/200, Loss: 7.2115


Epoch 23/200: 100%|██████████| 4/4 [00:08<00:00,  2.05s/it, loss=7.23]


Epoch 23/200, Loss: 7.2275


Epoch 24/200: 100%|██████████| 4/4 [00:07<00:00,  1.91s/it, loss=7.26]


Epoch 24/200, Loss: 7.2565


Epoch 25/200: 100%|██████████| 4/4 [00:07<00:00,  1.95s/it, loss=7.29]


Epoch 25/200, Loss: 7.2853


Epoch 26/200: 100%|██████████| 4/4 [00:07<00:00,  1.97s/it, loss=7.31]


Epoch 26/200, Loss: 7.3097


Epoch 27/200: 100%|██████████| 4/4 [00:08<00:00,  2.04s/it, loss=7.33]


Epoch 27/200, Loss: 7.3343


Epoch 28/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.35]


Epoch 28/200, Loss: 7.3518


Epoch 29/200: 100%|██████████| 4/4 [00:07<00:00,  1.92s/it, loss=7.38]


Epoch 29/200, Loss: 7.3786


Epoch 30/200: 100%|██████████| 4/4 [00:07<00:00,  1.98s/it, loss=7.4]


Epoch 30/200, Loss: 7.4017


Epoch 31/200: 100%|██████████| 4/4 [00:08<00:00,  2.04s/it, loss=7.43]


Epoch 31/200, Loss: 7.4263


Epoch 32/200: 100%|██████████| 4/4 [00:07<00:00,  1.92s/it, loss=7.45]


Epoch 32/200, Loss: 7.4473


Epoch 33/200: 100%|██████████| 4/4 [00:07<00:00,  1.92s/it, loss=7.47]


Epoch 33/200, Loss: 7.4726


Epoch 34/200: 100%|██████████| 4/4 [00:07<00:00,  1.94s/it, loss=7.49]


Epoch 34/200, Loss: 7.4911


Epoch 35/200: 100%|██████████| 4/4 [00:08<00:00,  2.00s/it, loss=7.5]


Epoch 35/200, Loss: 7.5042


Epoch 36/200: 100%|██████████| 4/4 [00:07<00:00,  1.95s/it, loss=7.52]


Epoch 36/200, Loss: 7.5225


Epoch 37/200: 100%|██████████| 4/4 [00:07<00:00,  1.98s/it, loss=7.54]


Epoch 37/200, Loss: 7.5401


Epoch 38/200: 100%|██████████| 4/4 [00:07<00:00,  1.92s/it, loss=7.56]


Epoch 38/200, Loss: 7.5621


Epoch 39/200: 100%|██████████| 4/4 [00:08<00:00,  2.01s/it, loss=7.58]


Epoch 39/200, Loss: 7.5800


Epoch 40/200: 100%|██████████| 4/4 [00:07<00:00,  1.88s/it, loss=7.59]


Epoch 40/200, Loss: 7.5916


Epoch 41/200: 100%|██████████| 4/4 [00:07<00:00,  1.94s/it, loss=7.61]


Epoch 41/200, Loss: 7.6110


Epoch 42/200: 100%|██████████| 4/4 [00:07<00:00,  2.00s/it, loss=7.62]


Epoch 42/200, Loss: 7.6196


Epoch 43/200: 100%|██████████| 4/4 [00:08<00:00,  2.01s/it, loss=7.64]


Epoch 43/200, Loss: 7.6437


Epoch 44/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.66]


Epoch 44/200, Loss: 7.6625


Epoch 45/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.67]


Epoch 45/200, Loss: 7.6720


Epoch 46/200: 100%|██████████| 4/4 [00:08<00:00,  2.00s/it, loss=7.68]


Epoch 46/200, Loss: 7.6834


Epoch 47/200: 100%|██████████| 4/4 [00:08<00:00,  2.02s/it, loss=7.69]


Epoch 47/200, Loss: 7.6899


Epoch 48/200: 100%|██████████| 4/4 [00:07<00:00,  1.94s/it, loss=7.69]


Epoch 48/200, Loss: 7.6922


Epoch 49/200: 100%|██████████| 4/4 [00:07<00:00,  1.97s/it, loss=7.7]


Epoch 49/200, Loss: 7.7022


Epoch 50/200: 100%|██████████| 4/4 [00:07<00:00,  1.93s/it, loss=7.71]


Epoch 50/200, Loss: 7.7093


Epoch 51/200: 100%|██████████| 4/4 [00:08<00:00,  2.04s/it, loss=7.72]


Epoch 51/200, Loss: 7.7180


Epoch 52/200: 100%|██████████| 4/4 [00:07<00:00,  1.98s/it, loss=7.73]


Epoch 52/200, Loss: 7.7322


Epoch 53/200: 100%|██████████| 4/4 [00:08<00:00,  2.00s/it, loss=7.74]


Epoch 53/200, Loss: 7.7363


Epoch 54/200: 100%|██████████| 4/4 [00:07<00:00,  1.93s/it, loss=7.76]


Epoch 54/200, Loss: 7.7556


Epoch 55/200: 100%|██████████| 4/4 [00:08<00:00,  2.01s/it, loss=7.76]


Epoch 55/200, Loss: 7.7631


Epoch 56/200: 100%|██████████| 4/4 [00:08<00:00,  2.02s/it, loss=7.77]


Epoch 56/200, Loss: 7.7655


Epoch 57/200: 100%|██████████| 4/4 [00:08<00:00,  2.04s/it, loss=7.77]


Epoch 57/200, Loss: 7.7728


Epoch 58/200: 100%|██████████| 4/4 [00:08<00:00,  2.08s/it, loss=7.78]


Epoch 58/200, Loss: 7.7775


Epoch 59/200: 100%|██████████| 4/4 [00:08<00:00,  2.19s/it, loss=7.78]


Epoch 59/200, Loss: 7.7812


Epoch 60/200: 100%|██████████| 4/4 [00:08<00:00,  2.01s/it, loss=7.79]


Epoch 60/200, Loss: 7.7932


Epoch 61/200: 100%|██████████| 4/4 [00:08<00:00,  2.08s/it, loss=7.8]


Epoch 61/200, Loss: 7.8011


Epoch 62/200: 100%|██████████| 4/4 [00:08<00:00,  2.07s/it, loss=7.81]


Epoch 62/200, Loss: 7.8141


Epoch 63/200: 100%|██████████| 4/4 [00:08<00:00,  2.15s/it, loss=7.82]


Epoch 63/200, Loss: 7.8229


Epoch 64/200: 100%|██████████| 4/4 [00:08<00:00,  2.03s/it, loss=7.8]


Epoch 64/200, Loss: 7.8003


Epoch 65/200: 100%|██████████| 4/4 [00:08<00:00,  2.05s/it, loss=7.82]


Epoch 65/200, Loss: 7.8238


Epoch 66/200: 100%|██████████| 4/4 [00:08<00:00,  2.00s/it, loss=7.81]


Epoch 66/200, Loss: 7.8103


Epoch 67/200: 100%|██████████| 4/4 [00:08<00:00,  2.12s/it, loss=7.82]


Epoch 67/200, Loss: 7.8167


Epoch 68/200: 100%|██████████| 4/4 [00:08<00:00,  2.05s/it, loss=7.84]


Epoch 68/200, Loss: 7.8392


Epoch 69/200: 100%|██████████| 4/4 [00:08<00:00,  2.08s/it, loss=7.84]


Epoch 69/200, Loss: 7.8353


Epoch 70/200: 100%|██████████| 4/4 [00:08<00:00,  2.08s/it, loss=7.83]


Epoch 70/200, Loss: 7.8346


Epoch 71/200: 100%|██████████| 4/4 [00:08<00:00,  2.08s/it, loss=7.85]


Epoch 71/200, Loss: 7.8505


Epoch 72/200: 100%|██████████| 4/4 [00:07<00:00,  2.00s/it, loss=7.86]


Epoch 72/200, Loss: 7.8600


Epoch 73/200: 100%|██████████| 4/4 [00:07<00:00,  1.97s/it, loss=7.85]


Epoch 73/200, Loss: 7.8477


Epoch 74/200: 100%|██████████| 4/4 [00:07<00:00,  1.99s/it, loss=7.86]


Epoch 74/200, Loss: 7.8562


Epoch 75/200: 100%|██████████| 4/4 [00:08<00:00,  2.05s/it, loss=7.85]


Epoch 75/200, Loss: 7.8531


Epoch 76/200: 100%|██████████| 4/4 [00:07<00:00,  1.94s/it, loss=7.85]


Epoch 76/200, Loss: 7.8519


Epoch 77/200: 100%|██████████| 4/4 [00:07<00:00,  1.95s/it, loss=7.84]


Epoch 77/200, Loss: 7.8383


Epoch 78/200: 100%|██████████| 4/4 [00:07<00:00,  1.92s/it, loss=7.86]


Epoch 78/200, Loss: 7.8638


Epoch 79/200: 100%|██████████| 4/4 [00:08<00:00,  2.01s/it, loss=7.85]


Epoch 79/200, Loss: 7.8500


Epoch 80/200: 100%|██████████| 4/4 [00:07<00:00,  1.95s/it, loss=7.88]


Epoch 80/200, Loss: 7.8840


Epoch 81/200: 100%|██████████| 4/4 [00:07<00:00,  1.92s/it, loss=7.89]


Epoch 81/200, Loss: 7.8880


Epoch 82/200: 100%|██████████| 4/4 [00:07<00:00,  1.95s/it, loss=7.86]


Epoch 82/200, Loss: 7.8579


Epoch 83/200: 100%|██████████| 4/4 [00:08<00:00,  2.01s/it, loss=7.87]


Epoch 83/200, Loss: 7.8665


Epoch 84/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.85]


Epoch 84/200, Loss: 7.8452


Epoch 85/200: 100%|██████████| 4/4 [00:07<00:00,  1.99s/it, loss=7.84]


Epoch 85/200, Loss: 7.8370


Epoch 86/200: 100%|██████████| 4/4 [00:07<00:00,  1.91s/it, loss=7.86]


Epoch 86/200, Loss: 7.8552


Epoch 87/200: 100%|██████████| 4/4 [00:08<00:00,  2.04s/it, loss=7.85]


Epoch 87/200, Loss: 7.8503


Epoch 88/200: 100%|██████████| 4/4 [00:07<00:00,  1.95s/it, loss=7.86]


Epoch 88/200, Loss: 7.8624


Epoch 89/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.89]


Epoch 89/200, Loss: 7.8873


Epoch 90/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.88]


Epoch 90/200, Loss: 7.8758


Epoch 91/200: 100%|██████████| 4/4 [00:08<00:00,  2.05s/it, loss=7.85]


Epoch 91/200, Loss: 7.8522


Epoch 92/200: 100%|██████████| 4/4 [00:07<00:00,  1.98s/it, loss=7.85]


Epoch 92/200, Loss: 7.8539


Epoch 93/200: 100%|██████████| 4/4 [00:08<00:00,  2.05s/it, loss=7.83]


Epoch 93/200, Loss: 7.8300


Epoch 94/200: 100%|██████████| 4/4 [00:07<00:00,  1.97s/it, loss=7.88]


Epoch 94/200, Loss: 7.8824


Epoch 95/200: 100%|██████████| 4/4 [00:08<00:00,  2.03s/it, loss=7.86]


Epoch 95/200, Loss: 7.8639


Epoch 96/200: 100%|██████████| 4/4 [00:07<00:00,  1.95s/it, loss=7.84]


Epoch 96/200, Loss: 7.8419


Epoch 97/200: 100%|██████████| 4/4 [00:07<00:00,  1.92s/it, loss=7.83]


Epoch 97/200, Loss: 7.8279


Epoch 98/200: 100%|██████████| 4/4 [00:07<00:00,  1.93s/it, loss=7.86]


Epoch 98/200, Loss: 7.8599


Epoch 99/200: 100%|██████████| 4/4 [00:08<00:00,  2.10s/it, loss=7.85]


Epoch 99/200, Loss: 7.8502


Epoch 100/200: 100%|██████████| 4/4 [00:07<00:00,  1.97s/it, loss=7.85]


Epoch 100/200, Loss: 7.8547


Epoch 101/200: 100%|██████████| 4/4 [00:07<00:00,  1.99s/it, loss=7.88]


Epoch 101/200, Loss: 7.8816


Epoch 102/200: 100%|██████████| 4/4 [00:07<00:00,  1.92s/it, loss=7.85]


Epoch 102/200, Loss: 7.8534


Epoch 103/200: 100%|██████████| 4/4 [00:08<00:00,  2.05s/it, loss=7.86]


Epoch 103/200, Loss: 7.8554


Epoch 104/200: 100%|██████████| 4/4 [00:07<00:00,  1.95s/it, loss=7.88]


Epoch 104/200, Loss: 7.8798


Epoch 105/200: 100%|██████████| 4/4 [00:07<00:00,  1.99s/it, loss=7.87]


Epoch 105/200, Loss: 7.8674


Epoch 106/200: 100%|██████████| 4/4 [00:07<00:00,  1.95s/it, loss=7.83]


Epoch 106/200, Loss: 7.8306


Epoch 107/200: 100%|██████████| 4/4 [00:08<00:00,  2.02s/it, loss=7.86]


Epoch 107/200, Loss: 7.8604


Epoch 108/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.85]


Epoch 108/200, Loss: 7.8538


Epoch 109/200: 100%|██████████| 4/4 [00:07<00:00,  1.95s/it, loss=7.8]


Epoch 109/200, Loss: 7.8045


Epoch 110/200: 100%|██████████| 4/4 [00:07<00:00,  1.92s/it, loss=7.85]


Epoch 110/200, Loss: 7.8497


Epoch 111/200: 100%|██████████| 4/4 [00:08<00:00,  2.03s/it, loss=7.87]


Epoch 111/200, Loss: 7.8689


Epoch 112/200: 100%|██████████| 4/4 [00:07<00:00,  1.91s/it, loss=7.86]


Epoch 112/200, Loss: 7.8643


Epoch 113/200: 100%|██████████| 4/4 [00:07<00:00,  1.90s/it, loss=7.83]


Epoch 113/200, Loss: 7.8318


Epoch 114/200: 100%|██████████| 4/4 [00:07<00:00,  1.89s/it, loss=7.85]


Epoch 114/200, Loss: 7.8527


Epoch 115/200: 100%|██████████| 4/4 [00:08<00:00,  2.04s/it, loss=7.85]


Epoch 115/200, Loss: 7.8538


Epoch 116/200: 100%|██████████| 4/4 [00:07<00:00,  1.94s/it, loss=7.87]


Epoch 116/200, Loss: 7.8683


Epoch 117/200: 100%|██████████| 4/4 [00:07<00:00,  1.95s/it, loss=7.82]


Epoch 117/200, Loss: 7.8227


Epoch 118/200: 100%|██████████| 4/4 [00:07<00:00,  1.95s/it, loss=7.81]


Epoch 118/200, Loss: 7.8122


Epoch 119/200: 100%|██████████| 4/4 [00:08<00:00,  2.05s/it, loss=7.85]


Epoch 119/200, Loss: 7.8525


Epoch 120/200: 100%|██████████| 4/4 [00:07<00:00,  1.94s/it, loss=7.8]


Epoch 120/200, Loss: 7.8043


Epoch 121/200: 100%|██████████| 4/4 [00:07<00:00,  1.93s/it, loss=7.82]


Epoch 121/200, Loss: 7.8212


Epoch 122/200: 100%|██████████| 4/4 [00:07<00:00,  1.91s/it, loss=7.84]


Epoch 122/200, Loss: 7.8430


Epoch 123/200: 100%|██████████| 4/4 [00:08<00:00,  2.06s/it, loss=7.81]


Epoch 123/200, Loss: 7.8115


Epoch 124/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.8]


Epoch 124/200, Loss: 7.7954


Epoch 125/200: 100%|██████████| 4/4 [00:07<00:00,  1.98s/it, loss=7.78]


Epoch 125/200, Loss: 7.7811


Epoch 126/200: 100%|██████████| 4/4 [00:07<00:00,  1.93s/it, loss=7.78]


Epoch 126/200, Loss: 7.7819


Epoch 127/200: 100%|██████████| 4/4 [00:08<00:00,  2.07s/it, loss=7.75]


Epoch 127/200, Loss: 7.7522


Epoch 128/200: 100%|██████████| 4/4 [00:07<00:00,  1.94s/it, loss=7.77]


Epoch 128/200, Loss: 7.7732


Epoch 129/200: 100%|██████████| 4/4 [00:07<00:00,  1.89s/it, loss=7.74]


Epoch 129/200, Loss: 7.7372


Epoch 130/200: 100%|██████████| 4/4 [00:07<00:00,  1.94s/it, loss=7.71]


Epoch 130/200, Loss: 7.7145


Epoch 131/200: 100%|██████████| 4/4 [00:07<00:00,  1.99s/it, loss=7.71]


Epoch 131/200, Loss: 7.7056


Epoch 132/200: 100%|██████████| 4/4 [00:07<00:00,  1.95s/it, loss=7.72]


Epoch 132/200, Loss: 7.7188


Epoch 133/200: 100%|██████████| 4/4 [00:07<00:00,  1.95s/it, loss=7.71]


Epoch 133/200, Loss: 7.7146


Epoch 134/200: 100%|██████████| 4/4 [00:07<00:00,  1.95s/it, loss=7.7]


Epoch 134/200, Loss: 7.7049


Epoch 135/200: 100%|██████████| 4/4 [00:08<00:00,  2.06s/it, loss=7.69]


Epoch 135/200, Loss: 7.6867


Epoch 136/200: 100%|██████████| 4/4 [00:07<00:00,  1.90s/it, loss=7.7]


Epoch 136/200, Loss: 7.6982


Epoch 137/200: 100%|██████████| 4/4 [00:07<00:00,  1.90s/it, loss=7.62]


Epoch 137/200, Loss: 7.6237


Epoch 138/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.63]


Epoch 138/200, Loss: 7.6299


Epoch 139/200: 100%|██████████| 4/4 [00:08<00:00,  2.04s/it, loss=7.64]


Epoch 139/200, Loss: 7.6449


Epoch 140/200: 100%|██████████| 4/4 [00:07<00:00,  1.92s/it, loss=7.63]


Epoch 140/200, Loss: 7.6320


Epoch 141/200: 100%|██████████| 4/4 [00:07<00:00,  1.90s/it, loss=7.62]


Epoch 141/200, Loss: 7.6226


Epoch 142/200: 100%|██████████| 4/4 [00:07<00:00,  1.90s/it, loss=7.67]


Epoch 142/200, Loss: 7.6650


Epoch 143/200: 100%|██████████| 4/4 [00:07<00:00,  1.97s/it, loss=7.65]


Epoch 143/200, Loss: 7.6450


Epoch 144/200: 100%|██████████| 4/4 [00:07<00:00,  1.93s/it, loss=7.64]


Epoch 144/200, Loss: 7.6420


Epoch 145/200: 100%|██████████| 4/4 [00:07<00:00,  2.00s/it, loss=7.6]


Epoch 145/200, Loss: 7.6011


Epoch 146/200: 100%|██████████| 4/4 [00:07<00:00,  1.94s/it, loss=7.65]


Epoch 146/200, Loss: 7.6492


Epoch 147/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.66]


Epoch 147/200, Loss: 7.6615


Epoch 148/200: 100%|██████████| 4/4 [00:07<00:00,  1.94s/it, loss=7.65]


Epoch 148/200, Loss: 7.6455


Epoch 149/200: 100%|██████████| 4/4 [00:07<00:00,  1.91s/it, loss=7.62]


Epoch 149/200, Loss: 7.6180


Epoch 150/200: 100%|██████████| 4/4 [00:07<00:00,  1.95s/it, loss=7.59]


Epoch 150/200, Loss: 7.5927


Epoch 151/200: 100%|██████████| 4/4 [00:07<00:00,  2.00s/it, loss=7.63]


Epoch 151/200, Loss: 7.6255


Epoch 152/200: 100%|██████████| 4/4 [00:07<00:00,  1.97s/it, loss=7.6]


Epoch 152/200, Loss: 7.5952


Epoch 153/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.58]


Epoch 153/200, Loss: 7.5841


Epoch 154/200: 100%|██████████| 4/4 [00:07<00:00,  1.94s/it, loss=7.57]


Epoch 154/200, Loss: 7.5742


Epoch 155/200: 100%|██████████| 4/4 [00:07<00:00,  1.97s/it, loss=7.63]


Epoch 155/200, Loss: 7.6330


Epoch 156/200: 100%|██████████| 4/4 [00:07<00:00,  1.92s/it, loss=7.61]


Epoch 156/200, Loss: 7.6068


Epoch 157/200: 100%|██████████| 4/4 [00:07<00:00,  1.97s/it, loss=7.59]


Epoch 157/200, Loss: 7.5932


Epoch 158/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.65]


Epoch 158/200, Loss: 7.6510


Epoch 159/200: 100%|██████████| 4/4 [00:08<00:00,  2.03s/it, loss=7.65]


Epoch 159/200, Loss: 7.6515


Epoch 160/200: 100%|██████████| 4/4 [00:07<00:00,  1.93s/it, loss=7.66]


Epoch 160/200, Loss: 7.6603


Epoch 161/200: 100%|██████████| 4/4 [00:07<00:00,  1.97s/it, loss=7.58]


Epoch 161/200, Loss: 7.5831


Epoch 162/200: 100%|██████████| 4/4 [00:07<00:00,  1.93s/it, loss=7.58]


Epoch 162/200, Loss: 7.5829


Epoch 163/200: 100%|██████████| 4/4 [00:08<00:00,  2.01s/it, loss=7.6]


Epoch 163/200, Loss: 7.5983


Epoch 164/200: 100%|██████████| 4/4 [00:07<00:00,  1.94s/it, loss=7.65]


Epoch 164/200, Loss: 7.6530


Epoch 165/200: 100%|██████████| 4/4 [00:07<00:00,  1.98s/it, loss=7.6]


Epoch 165/200, Loss: 7.5989


Epoch 166/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.58]


Epoch 166/200, Loss: 7.5760


Epoch 167/200: 100%|██████████| 4/4 [00:07<00:00,  1.98s/it, loss=7.63]


Epoch 167/200, Loss: 7.6285


Epoch 168/200: 100%|██████████| 4/4 [00:07<00:00,  1.97s/it, loss=7.56]


Epoch 168/200, Loss: 7.5604


Epoch 169/200: 100%|██████████| 4/4 [00:07<00:00,  1.93s/it, loss=7.59]


Epoch 169/200, Loss: 7.5881


Epoch 170/200: 100%|██████████| 4/4 [00:07<00:00,  1.95s/it, loss=7.61]


Epoch 170/200, Loss: 7.6146


Epoch 171/200: 100%|██████████| 4/4 [00:08<00:00,  2.00s/it, loss=7.57]


Epoch 171/200, Loss: 7.5746


Epoch 172/200: 100%|██████████| 4/4 [00:07<00:00,  1.92s/it, loss=7.64]


Epoch 172/200, Loss: 7.6387


Epoch 173/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.62]


Epoch 173/200, Loss: 7.6220


Epoch 174/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.6]


Epoch 174/200, Loss: 7.6009


Epoch 175/200: 100%|██████████| 4/4 [00:07<00:00,  1.99s/it, loss=7.6]


Epoch 175/200, Loss: 7.6008


Epoch 176/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.63]


Epoch 176/200, Loss: 7.6342


Epoch 177/200: 100%|██████████| 4/4 [00:07<00:00,  1.94s/it, loss=7.64]


Epoch 177/200, Loss: 7.6353


Epoch 178/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.59]


Epoch 178/200, Loss: 7.5898


Epoch 179/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.62]


Epoch 179/200, Loss: 7.6195


Epoch 180/200: 100%|██████████| 4/4 [00:07<00:00,  1.93s/it, loss=7.6]


Epoch 180/200, Loss: 7.5989


Epoch 181/200: 100%|██████████| 4/4 [00:07<00:00,  1.95s/it, loss=7.6]


Epoch 181/200, Loss: 7.6000


Epoch 182/200: 100%|██████████| 4/4 [00:07<00:00,  1.91s/it, loss=7.68]


Epoch 182/200, Loss: 7.6821


Epoch 183/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.59]


Epoch 183/200, Loss: 7.5931


Epoch 184/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.66]


Epoch 184/200, Loss: 7.6644


Epoch 185/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.61]


Epoch 185/200, Loss: 7.6138


Epoch 186/200: 100%|██████████| 4/4 [00:08<00:00,  2.01s/it, loss=7.66]


Epoch 186/200, Loss: 7.6618


Epoch 187/200: 100%|██████████| 4/4 [00:08<00:00,  2.02s/it, loss=7.65]


Epoch 187/200, Loss: 7.6486


Epoch 188/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.69]


Epoch 188/200, Loss: 7.6853


Epoch 189/200: 100%|██████████| 4/4 [00:07<00:00,  1.93s/it, loss=7.63]


Epoch 189/200, Loss: 7.6272


Epoch 190/200: 100%|██████████| 4/4 [00:07<00:00,  1.93s/it, loss=7.64]


Epoch 190/200, Loss: 7.6393


Epoch 191/200: 100%|██████████| 4/4 [00:08<00:00,  2.00s/it, loss=7.64]


Epoch 191/200, Loss: 7.6440


Epoch 192/200: 100%|██████████| 4/4 [00:07<00:00,  1.95s/it, loss=7.68]


Epoch 192/200, Loss: 7.6822


Epoch 193/200: 100%|██████████| 4/4 [00:07<00:00,  1.90s/it, loss=7.67]


Epoch 193/200, Loss: 7.6651


Epoch 194/200: 100%|██████████| 4/4 [00:07<00:00,  1.93s/it, loss=7.67]


Epoch 194/200, Loss: 7.6722


Epoch 195/200: 100%|██████████| 4/4 [00:08<00:00,  2.02s/it, loss=7.7]


Epoch 195/200, Loss: 7.6966


Epoch 196/200: 100%|██████████| 4/4 [00:07<00:00,  1.95s/it, loss=7.71]


Epoch 196/200, Loss: 7.7096


Epoch 197/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.74]


Epoch 197/200, Loss: 7.7360


Epoch 198/200: 100%|██████████| 4/4 [00:07<00:00,  1.97s/it, loss=7.73]


Epoch 198/200, Loss: 7.7329


Epoch 199/200: 100%|██████████| 4/4 [00:07<00:00,  1.96s/it, loss=7.74]


Epoch 199/200, Loss: 7.7444


Epoch 200/200: 100%|██████████| 4/4 [00:07<00:00,  1.97s/it, loss=7.73]


Epoch 200/200, Loss: 7.7260
MoCo pretraining completed!

=== Phase 2: Segmentation Training ===
Using ImageNet pretrained weights...


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 88.2MB/s]
Training Epoch 1:   0%|          | 0/30 [00:00<?, ?it/s]


TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/worker.py", line 351, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "<ipython-input-1-2a9ba73418c2>", line 318, in __getitem__
    mask = (mask > 127).astype(np.float32)
TypeError: '>' not supported between instances of 'NoneType' and 'int'
