In [1]:
from network import R2AttU_Net
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # model = R2AttU_Net()
# # model = model.to(device)
# # sample_input = torch.randn(1, 3, 224, 224)

# # # Move the sample input to GPU if the model is on GPU
# # sample_input = sample_input.to(device)

# # # Forward pass through the model
# # output = model(sample_input)

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
from patchify import patchify, unpatchify  # Only to handle large images
import random
from PIL import Image
from datasets import Dataset

def load_and_pad_images_from_folder(folder, target_size=(1024, 1536)):
    images = []
    filenames = []
    for filename in sorted(os.listdir(folder)):  # Sort the filenames
        img_path = os.path.join(folder, filename)
        if os.path.isfile(img_path):
            img = Image.open(img_path).convert('RGB')  # Ensure image is in RGB
            img_array = np.array(img)
            # Pad the image
            pad_width = ((0, 0), (0, target_size[1] - img_array.shape[1]), (0, 0))
            padded_img = np.pad(img_array, pad_width, mode='constant', constant_values=0)
            images.append(padded_img)
            filenames.append(filename)
    return np.array(images), filenames

def load_and_pad_masks_from_folder(folder, target_size=(1024, 1536)):
    masks = []
    filenames = []
    for filename in sorted(os.listdir(folder)):  # Sort the filenames
        mask_path = os.path.join(folder, filename)
        if os.path.isfile(mask_path):
            mask = np.load(mask_path)
            # Pad the mask
            pad_width = ((0, 0), (0, target_size[1] - mask.shape[1]))
            padded_mask = np.pad(mask, pad_width, mode='constant', constant_values=0)
            masks.append(padded_mask)
            filenames.append(filename)
    return np.array(masks), filenames






# Folder paths
image_folder_path = r'/data1/sprasad/data/train/image3'
mask_folder_path = r'/data1/sprasad/data/train/label3'

# Load and pad images and masks
large_images, image_filenames = load_and_pad_images_from_folder(image_folder_path)
large_masks, mask_filenames = load_and_pad_masks_from_folder(mask_folder_path)

# Ensure the filenames are sorted and correspond
image_filenames = sorted(image_filenames)
mask_filenames = sorted(mask_filenames)

# Print shapes
print(f'Shape of large_images: {large_images.shape}')  # Should be (18, 1024, 1536, 3)
print(f'Shape of large_masks: {large_masks.shape}')    # Should be (18, 1024, 1536)

# Desired patch size for smaller images and step size.
patch_size = 256
step = 256
all_img_patches = []
all_mask_patches = []
all_patch_filenames = []

for img_idx in range(large_images.shape[0]):
    large_image = large_images[img_idx]
    patches_img = patchify(large_image, (patch_size, patch_size, 3), step=step)  # Step=256 for 256 patches means no overlap

    large_mask = large_masks[img_idx]
    patches_mask = patchify(large_mask, (patch_size, patch_size), step=step)  # Step=256 for 256 patches means no overlap

    for i in range(patches_img.shape[0]):
        for j in range(patches_img.shape[1]):
            single_patch_img = patches_img[i, j, :, :, :]
            single_patch_mask = patches_mask[i, j, :, :]

            all_img_patches.append(single_patch_img)
            all_mask_patches.append(single_patch_mask)

            # Generate a patch filename for tracking
            patch_filename = f"{image_filenames[img_idx].split('.')[0]}_patch_{i}_{j}.npy"
            all_patch_filenames.append(patch_filename)

images = np.array(all_img_patches)/255.0
masks = np.array(all_mask_patches)

# Create a list to store the indices of non-empty masks
valid_indices = [i for i, mask in enumerate(masks) if mask.max() != 0]

# Filter the image and mask arrays to keep only the non-empty pairs
filtered_images = images[valid_indices]
filtered_masks = masks[valid_indices]
filtered_filenames = [all_patch_filenames[i] for i in valid_indices]

print("Image shape:", filtered_images.shape)  # e.g., (num_frames, height, width, num_channels)
print("Mask shape:", filtered_masks.shape)

# Remove the extra dimension from images
filtered_images = np.squeeze(filtered_images, axis=1)
#filtered_images = filtered_images.astype(np.uint8)
#filtered_masks = filtered_masks.astype(np.uint8)


# # Convert the NumPy arrays to Pillow images and store them in a dictionary
# dataset_dict = {
#     "image": [Image.fromarray(img) for img in filtered_images],
#     "label": [Image.fromarray(mask) for mask in filtered_masks],
#     "filename": filtered_filenames
# }


# # Create the dataset using the datasets.Dataset class
# dataset = Dataset.from_dict(dataset_dict)

  from .autonotebook import tqdm as notebook_tqdm


Shape of large_images: (42, 1024, 1536, 3)
Shape of large_masks: (42, 1024, 1536)
Image shape: (310, 1, 256, 256, 3)
Mask shape: (310, 256, 256)


In [5]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from torchvision import transforms

class SemanticSegmentationDataset(Dataset):
    def __init__(self, images, masks, transform=None):
        self.images = images
        self.masks = masks
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        mask = self.masks[idx]

        if self.transform:
            # Apply transformations
            image = self.transform(image)
            #mask = self.transform(mask)
            mask=torch.tensor(mask, dtype=torch.long)

        return image, mask


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# # Convert images and masks to PyTorch tensors
# filtered_images = torch.tensor(filtered_images, dtype=torch.float32).permute(0, 3, 1, 2)  # Convert to (N, C, H, W)
# filtered_masks = torch.tensor(filtered_masks, dtype=torch.long)  # Masks typically have dtype long for segmentation


In [3]:
# Normalizing and transforming images
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [None]:
# Create dataset instances
train_dataset = SemanticSegmentationDataset(filtered_images, filtered_masks, transform=transform)
test_dataset = SemanticSegmentationDataset(filtered_images, filtered_masks, transform=transform)  # Use separate data for testing in practice

# Create DataLoader instances
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)

In [5]:
filtered_images[0].transpose((0,1,2)).shape

(256, 256, 3)

## Sanity Check

In [None]:
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])




for i,l in train_loader:
    print(i.shape)
    print(type(i))
    i=i.squeeze(0)
    print(i.shape)
    i=i.permute(1,2,0)
    l=l.permute(1,2,0)
    print(i.shape)
    i=std*i +mean

    plt.figure(figsize=(12,6))
    plt.subplot(1,2,1)
    plt.imshow(i)
    plt.subplot(1, 2, 2)
    plt.imshow(l, cmap='gray')

    plt.show()
    print(type(i),type(l))

    #print(np.min(i.numpy()))
      
    break

In [10]:
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = R2AttU_Net()
model = model.to(device)
sample_input = torch.randn(1, 3, 224, 224)

# Move the sample input to GPU if the model is on GPU
sample_input = sample_input.to(device)

# Forward pass through the model
output = model(sample_input)

In [11]:

import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

In [12]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class FocalLoss(nn.Module):
#     def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
#         super(FocalLoss, self).__init__()
#         self.alpha = alpha
#         self.gamma = gamma
#         self.logits = logits
#         self.reduce = reduce

#     def forward(self, inputs, targets):
#         inputs=inputs.squeeze(1)
#         if self.logits:
#             BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
#         else:
#             BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
#         pt = torch.exp(-BCE_loss)
#         F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

#         if self.reduce:
#             return torch.mean(F_loss)
#         else:
#             return F_loss

# # Usage
# criterion = FocalLoss(logits=True)
criterion = nn.BCEWithLogitsLoss()  # Binary Cross-Entropy Loss with Logits


In [13]:
def jaccard_index(preds, targets, threshold=0.5):
    preds = (preds > threshold).float()
    preds = preds.bool()  # Convert to boolean for bitwise operations
    targets = targets.bool()  # Convert to boolean for bitwise operations

    intersection = (preds & targets).float().sum((1, 2))
    union = (preds | targets).float().sum((1, 2))
    iou = (intersection + 1e-10) / (union + 1e-10)  # avoid division by zero
    return iou.mean()  # return mean IoU over the batch


## Trained For 92 epochs

In [14]:
# Define the directory to save the model weights
model.load_state_dict(torch.load('/data1/sprasad/attntnadvUnet/attentionR2u/weights3/model_epoch_66.pth'))
save_dir = '/data1/sprasad/attntnadvUnet/attentionR2u/weights3'
os.makedirs(save_dir, exist_ok=True)

# Training loop
num_epochs = 133
optimizer = optim.Adam(model.parameters(), lr=1e-4)

model.train()



for epoch in range(num_epochs):
    model.train()
    epoch_losses = []
    for images, masks in tqdm(train_loader):
        # Move images and masks to device
        images=images.float()
        images = images.to(device)
        masks = masks.to(device)
        masks = masks.float()

        # Forward pass
        outputs = model(images)

        # Compute loss
        outputs=outputs.squeeze(1)
        loss = criterion(outputs, masks)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Optimize
        optimizer.step()
        epoch_losses.append(loss.item())

    # Save model after each epoch
    save_path = os.path.join(save_dir, f'model_epoch_{epoch+67}.pth')
    torch.save(model.state_dict(), save_path)

    # Compute mean loss for the epoch
    mean_loss = sum(epoch_losses) / len(epoch_losses)
    print(f'EPOCH: {epoch+67}')
    print(f'Mean loss: {mean_loss:.4f}')

    # Evaluate on validation set
    model.eval()
    with torch.no_grad():
        iou_scores = []
        for val_images, val_masks in test_loader:
            val_images=val_images.float()
            val_images = val_images.to(device)
            val_masks = val_masks.to(device)

            val_outputs = model(val_images)
            val_outputs = torch.sigmoid(val_outputs)  # Apply sigmoid if the model does not include it

            iou = jaccard_index(val_outputs, val_masks)
            iou_scores.append(iou.cpu().numpy()) 

        mean_iou = np.mean(iou_scores)
        print(f'Mean Jaccard Index: {mean_iou:.4f}')

100%|██████████| 310/310 [00:26<00:00, 11.79it/s]


EPOCH: 67
Mean loss: 0.0293
Mean Jaccard Index: 0.6805


100%|██████████| 310/310 [00:24<00:00, 12.67it/s]


EPOCH: 68
Mean loss: 0.0251
Mean Jaccard Index: 0.4846


100%|██████████| 310/310 [00:22<00:00, 13.87it/s]


EPOCH: 69
Mean loss: 0.0281
Mean Jaccard Index: 0.6413


100%|██████████| 310/310 [00:23<00:00, 13.08it/s]


EPOCH: 70
Mean loss: 0.0256
Mean Jaccard Index: 0.7420


100%|██████████| 310/310 [00:25<00:00, 11.99it/s]


EPOCH: 71
Mean loss: 0.0239
Mean Jaccard Index: 0.7554


100%|██████████| 310/310 [00:24<00:00, 12.75it/s]


EPOCH: 72
Mean loss: 0.0318
Mean Jaccard Index: 0.7660


100%|██████████| 310/310 [00:21<00:00, 14.35it/s]


EPOCH: 73
Mean loss: 0.0365
Mean Jaccard Index: 0.7722


100%|██████████| 310/310 [00:23<00:00, 13.41it/s]


EPOCH: 74
Mean loss: 0.0236
Mean Jaccard Index: 0.7723


100%|██████████| 310/310 [00:24<00:00, 12.60it/s]


EPOCH: 75
Mean loss: 0.0208
Mean Jaccard Index: 0.7906


100%|██████████| 310/310 [00:25<00:00, 12.31it/s]


EPOCH: 76
Mean loss: 0.0198
Mean Jaccard Index: 0.7752


100%|██████████| 310/310 [00:30<00:00, 10.29it/s]


EPOCH: 77
Mean loss: 0.0210
Mean Jaccard Index: 0.7637


100%|██████████| 310/310 [00:31<00:00,  9.95it/s]


EPOCH: 78
Mean loss: 0.0230
Mean Jaccard Index: 0.7761


100%|██████████| 310/310 [00:22<00:00, 13.73it/s]


EPOCH: 79
Mean loss: 0.0284
Mean Jaccard Index: 0.7660


100%|██████████| 310/310 [00:21<00:00, 14.48it/s]


EPOCH: 80
Mean loss: 0.0219
Mean Jaccard Index: 0.7695


100%|██████████| 310/310 [00:21<00:00, 14.72it/s]


EPOCH: 81
Mean loss: 0.0191
Mean Jaccard Index: 0.7602


100%|██████████| 310/310 [00:22<00:00, 13.83it/s]


EPOCH: 82
Mean loss: 0.0176
Mean Jaccard Index: 0.7791


100%|██████████| 310/310 [00:21<00:00, 14.72it/s]


EPOCH: 83
Mean loss: 0.0209
Mean Jaccard Index: 0.7305


100%|██████████| 310/310 [00:23<00:00, 13.28it/s]


EPOCH: 84
Mean loss: 0.0382
Mean Jaccard Index: 0.7545


100%|██████████| 310/310 [00:24<00:00, 12.42it/s]


EPOCH: 85
Mean loss: 0.0231
Mean Jaccard Index: 0.7081


100%|██████████| 310/310 [00:21<00:00, 14.14it/s]


EPOCH: 86
Mean loss: 0.0173
Mean Jaccard Index: 0.7680


100%|██████████| 310/310 [00:22<00:00, 13.68it/s]


EPOCH: 87
Mean loss: 0.0170
Mean Jaccard Index: 0.7863


100%|██████████| 310/310 [00:22<00:00, 13.51it/s]


EPOCH: 88
Mean loss: 0.0167
Mean Jaccard Index: 0.7744


100%|██████████| 310/310 [00:25<00:00, 12.20it/s]


EPOCH: 89
Mean loss: 0.0179
Mean Jaccard Index: 0.7324


100%|██████████| 310/310 [00:28<00:00, 11.05it/s]


EPOCH: 90
Mean loss: 0.0176
Mean Jaccard Index: 0.7796


100%|██████████| 310/310 [00:19<00:00, 15.90it/s]


EPOCH: 91
Mean loss: 0.0173
Mean Jaccard Index: 0.7800


100%|██████████| 310/310 [00:19<00:00, 15.99it/s]


EPOCH: 92
Mean loss: 0.0176
Mean Jaccard Index: 0.7260


100%|██████████| 310/310 [00:19<00:00, 15.80it/s]


EPOCH: 93
Mean loss: 0.0168


KeyboardInterrupt: 