In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [12]:
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import os
import numpy as np

class TiledSegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None, tile_size=(512, 512), overlap=51, grayscale=True, target_size=(3840, 2160)):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.tile_size = tile_size
        self.overlap = overlap
        self.grayscale = grayscale
        self.target_size = target_size
        self.images = [f.split('.')[0] for f in os.listdir(image_dir) if f.endswith(('.jpg', '.png'))]
        self.tile_info = self._prepare_tile_info()  # Precompute all (image, tile_box) pairs

    def _prepare_tile_info(self):
        # Precompute tile boxes for each image to handle overlap and avoid recalculating in __getitem__
        tile_info = []
        tile_width, tile_height = self.tile_size

        for img_name in self.images:
            img_path = os.path.join(self.image_dir, f"{img_name}.jpg")
            image = Image.open(img_path).resize(self.target_size, Image.BICUBIC)
            img_width, img_height = image.size

            for i in range(0, img_width - tile_width + 1, tile_width - self.overlap):
                for j in range(0, img_height - tile_height + 1, tile_height - self.overlap):
                    tile_info.append((img_name, (i, j, i + tile_width, j + tile_height)))
        return tile_info

    def __len__(self):
        return len(self.tile_info)

    def __getitem__(self, idx):
        img_name, box = self.tile_info[idx]
        img_path = os.path.join(self.image_dir, f"{img_name}.jpg")
        mask_path = os.path.join(self.mask_dir, f"{img_name}_lab.png")

        image = Image.open(img_path).resize(self.target_size, Image.BICUBIC)
        mask = Image.open(mask_path).resize(self.target_size, Image.NEAREST).convert("L")

        if self.grayscale:
            image = image.convert("L")
        else:
            image = image.convert("RGB")

        # Crop image and mask
        img_tile = image.crop(box)
        mask_tile = mask.crop(box)

        # Apply transforms
        if self.transform:
            img_tile = self.transform(img_tile)
        else:
            img_tile = transforms.ToTensor()(img_tile)

        mask_tile = torch.tensor(np.array(mask_tile), dtype=torch.long)

        return img_tile, mask_tile  # Now returns individual (image tile, mask tile) pairs




In [None]:
from torchvision import transforms
from torch.utils.data import DataLoader

# Paths to each subset
target_dir = '/content/drive/MyDrive/RescueNet_dataset/RescueNet'

train_image_dir = os.path.join(target_dir, 'train/train-org-img')
train_mask_dir = os.path.join(target_dir, 'train/train-label-img')

val_image_dir = os.path.join(target_dir, 'val/val-org-img')
val_mask_dir = os.path.join(target_dir, 'val/val-label-img')

test_image_dir = os.path.join(target_dir, 'test/test-org-img')
test_mask_dir = os.path.join(target_dir, 'test/test-label-img')

grayscale = False
tile_size = (512, 512)
batch_size = 8
overlap = 51  # 10% overlap for 512 x 512 tiles
target_size = (3840, 2160)  # Resize images to 3840 x 2160

# Example transformations for image tiles
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]) if grayscale else transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Instantiate the dataset
train_dataset = TiledSegmentationDataset(
    image_dir=train_image_dir,
    mask_dir=train_mask_dir,
    transform=transform,
    tile_size=tile_size,
    overlap=overlap,
    grayscale=grayscale,
    target_size=target_size
)

val_dataset = TiledSegmentationDataset(
    image_dir=val_image_dir,
    mask_dir=val_mask_dir,
    transform=transform,
    tile_size=tile_size,
    overlap=overlap,
    grayscale=grayscale,
    target_size=target_size
)

test_dataset = TiledSegmentationDataset(
    image_dir=test_image_dir,
    mask_dir=test_mask_dir,
    transform=transform,
    tile_size=tile_size,
    overlap=overlap,
    grayscale=grayscale,
    target_size=target_size
)

# Wrap each dataset in a DataLoader
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class JPU(nn.Module):
    def __init__(self, in_channels, width=512):
        super(JPU, self).__init__()
        self.conv5 = nn.Sequential(
            nn.Conv2d(in_channels[0], width, 3, padding=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(in_channels[1], width, 3, padding=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels[2], width, 3, padding=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True)
        )

    def forward(self, *inputs):
        feats = [self.conv5(inputs[0]), self.conv4(inputs[1]), self.conv3(inputs[2])]
        size = feats[-1].shape[2:]
        feats = [F.interpolate(f, size=size, mode='bilinear', align_corners=True) for f in feats]
        feat = torch.cat(feats, dim=1)
        return feat

class FastFCN(nn.Module):
    def __init__(self, num_classes=12):
        super(FastFCN, self).__init__()

        # Load a ResNet-50 backbone pretrained on ImageNet
        backbone = models.resnet50(pretrained=True)

        # Extract feature layers from ResNet
        self.layer0 = nn.Sequential(
            backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool)
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        # JPU module with correct input channels
        self.jpu = JPU([2048, 1024, 512], width=512)  # Adjusted to match actual channels

        # Segmentation Head
        self.seg_head = nn.Sequential(
            nn.Conv2d(3 * 512, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, num_classes, kernel_size=1)
        )

    def forward(self, x):
        # Extract features from backbone
        x0 = self.layer0(x)
        # print("Layer 0 output:", x0.shape)
        x1 = self.layer1(x0)
        # print("Layer 1 output:", x1.shape)
        x2 = self.layer2(x1)
        # print("Layer 2 output:", x2.shape)
        x3 = self.layer3(x2)
        # print("Layer 3 output:", x3.shape)
        x4 = self.layer4(x3)
        # print("Layer 4 output:", x4.shape)

        # Apply JPU
        jpu_feat = self.jpu(x4, x3, x2)
        # print("JPU output:", jpu_feat.shape)

        # Apply segmentation head
        out = self.seg_head(jpu_feat)
        # print("Segmentation head output:", out.shape)

        # Upsample to the input image size
        out = F.interpolate(out, size=x.shape[2:], mode='bilinear', align_corners=True)
        return out

# Instantiate model with 12 classes
model = FastFCN(num_classes=12)

# Define the loss function
criterion = nn.CrossEntropyLoss()

# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


In [26]:
import time
import torch

def train(model, dataloader, criterion, optimizer, num_epochs, device):
    model = model.to(device)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        epoch_start_time = time.time()

        print(f"\nEpoch [{epoch + 1}/{num_epochs}]")

        for batch_idx, (images, masks) in enumerate(dataloader):
            batch_start_time = time.time()

            images = images.to(device)
            masks = masks.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update running loss
            running_loss += loss.item() * images.size(0)

            # Calculate time per batch
            batch_time = time.time() - batch_start_time
            print(f"Batch {batch_idx + 1}/{len(dataloader)}, Loss: {loss.item():.4f}, Time: {batch_time:.2f} seconds")

        # Calculate epoch loss and time
        epoch_loss = running_loss / len(dataloader.dataset)
        epoch_time = time.time() - epoch_start_time
        avg_batch_time = epoch_time / len(dataloader)

        print(f"Epoch Loss: {epoch_loss:.4f}")
        print(f"Epoch Time: {epoch_time:.2f} seconds, Avg Time per Batch: {avg_batch_time:.2f} seconds")

        # Estimate remaining time based on current epoch time
        remaining_time = epoch_time * (num_epochs - epoch - 1)
        print(f"Estimated Time Remaining: {remaining_time // 60:.0f}m {remaining_time % 60:.0f}s")

num_epochs = 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using:", device)

train(model, train_loader, criterion, optimizer, num_epochs=num_epochs, device=device)


Using: cuda

Epoch [1/2]
Batch 1/940, Loss: 2.7324, Time: 0.72 seconds
Batch 2/940, Loss: 2.8195, Time: 0.74 seconds
Batch 3/940, Loss: 1.7334, Time: 0.73 seconds
Batch 4/940, Loss: 1.2127, Time: 0.75 seconds
Batch 5/940, Loss: 1.3647, Time: 0.74 seconds
Batch 6/940, Loss: 1.2291, Time: 0.74 seconds
Batch 7/940, Loss: 2.9021, Time: 0.76 seconds
Batch 8/940, Loss: 2.1381, Time: 0.73 seconds
Batch 9/940, Loss: 1.2285, Time: 0.73 seconds
Batch 10/940, Loss: 2.0755, Time: 0.72 seconds
Batch 11/940, Loss: 1.4923, Time: 0.74 seconds
Batch 12/940, Loss: 1.4276, Time: 0.71 seconds
Batch 13/940, Loss: 0.8928, Time: 0.71 seconds
Batch 14/940, Loss: 0.6425, Time: 0.74 seconds
Batch 15/940, Loss: 2.1208, Time: 0.71 seconds
Batch 16/940, Loss: 2.0969, Time: 0.75 seconds
Batch 17/940, Loss: 1.2724, Time: 0.72 seconds
Batch 18/940, Loss: 1.0774, Time: 0.75 seconds
Batch 19/940, Loss: 0.8766, Time: 0.74 seconds
Batch 20/940, Loss: 0.8442, Time: 0.74 seconds
Batch 21/940, Loss: 1.2474, Time: 0.71 secon

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-26-0ee552ce2924>", line 52, in <cell line: 52>
    train(model, train_loader, criterion, optimizer, num_epochs=num_epochs, device=device)
  File "<ipython-input-26-0ee552ce2924>", line 14, in train
    for batch_idx, (images, masks) in enumerate(dataloader):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 701, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 757, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.10/dist-packages

TypeError: object of type 'NoneType' has no len()