In [64]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from sklearn.model_selection import train_test_split
import random
import numpy as np
from tqdm import tqdm
import argparse
import wandb
from os.path import splitext
from os import listdir
import numpy as np
import os
from glob import glob
import torch
from torch.utils.data import Dataset
import logging
from PIL import Image
from torchvision.transforms import functional as TF
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import imageio.v2 as imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from torchvision import models
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import glob
import wandb
import random
import numpy as np
from model import *
from utils import *

In [77]:

class ResNetFeatures(nn.Module):
    """Frozen ResNet-50 feature extractor initialised from a MoCo checkpoint.

    The classification head and the global average pool of the ResNet are
    dropped; the remaining convolutional trunk is followed by an adaptive
    average pool that resizes the feature map to ``output_size``.

    :param output_size: Target spatial size passed to ``nn.AdaptiveAvgPool2d``.
    :param checkpoint_path: Path of the MoCo checkpoint. It must contain a
        ``'state_dict'`` entry whose encoder keys start with
        ``'module.encoder_q.'``.
    """

    def __init__(self, output_size,
                 checkpoint_path='../../models/PyTorch/B3_rn50_moco_0099_ckpt.pth'):
        super(ResNetFeatures, self).__init__()
        resnet = models.resnet50(pretrained=False)
        # The head is replaced here but discarded again when the trunk is sliced below.
        resnet.fc = torch.nn.Linear(2048, 19)

        # map_location='cpu' lets a checkpoint saved from a GPU be loaded on any
        # host; the module is moved to its final device by the caller's .to().
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # Rename the MoCo pre-trained keys: keep only the query encoder, drop
        # its embedding (fc) layer and strip the wrapper prefix.
        prefix = 'module.encoder_q.'
        state_dict = {
            key[len(prefix):]: value
            for key, value in checkpoint['state_dict'].items()
            if key.startswith(prefix) and not key.startswith(prefix + 'fc')
        }

        # strict=False is required because the fc weights were removed above,
        # but any other mismatch means the backbone is (partly) uninitialised.
        load_result = resnet.load_state_dict(state_dict, strict=False)
        missing_backbone_keys = [k for k in load_result.missing_keys if not k.startswith('fc.')]
        if missing_backbone_keys or load_result.unexpected_keys:
            logging.warning(
                "Checkpoint %s did not fully match ResNet-50: missing=%s unexpected=%s",
                checkpoint_path, missing_backbone_keys, list(load_result.unexpected_keys),
            )

        # Remove the fully connected layer and the average pooling layer
        self.features = nn.Sequential(*list(resnet.children())[:-2])
        self.avgpool = nn.AdaptiveAvgPool2d(output_size)

        # Freeze the backbone: no gradients, and BatchNorm in inference mode.
        for param in self.features.parameters():
            param.requires_grad = False
        self.features.eval()

    def train(self, mode=True):
        """Switch modes like ``nn.Module.train`` but keep the backbone in eval mode.

        ``requires_grad = False`` does not stop BatchNorm layers from updating
        their running statistics in train mode, so the frozen trunk is pinned
        to eval mode regardless of ``mode``.
        """
        super().train(mode)
        self.features.eval()
        return self

    def forward(self, x):
        """Return the pooled backbone feature map for the image batch ``x``."""
        x = self.features(x)
        x = self.avgpool(x)
        return x

class FusionNet(nn.Module):
    """Collapse a multi-channel feature map to one channel and resize it.

    :param input_channels: Number of channels of the incoming feature map.
    :param output_size: Spatial size handed to ``nn.Upsample``.
    """

    def __init__(self, input_channels, output_size):
        super().__init__()
        # 1x1 convolution: a learned per-pixel mix of all channels into one.
        self.conv = nn.Conv2d(in_channels=input_channels, out_channels=1, kernel_size=1)
        self.upsample = nn.Upsample(size=output_size, mode='bilinear', align_corners=True)

    def forward(self, x):
        """Apply the 1x1 convolution, then bilinearly upsample the result."""
        return self.upsample(self.conv(x))


class RGB_DEM_to_SO(nn.Module):
    """Predict SO classes from a DEM tile plus a set of RGB views.

    Each RGB view goes through the shared ``ResNetFeatures`` backbone, the
    per-view features are concatenated and fused into a single-channel map,
    and that map is stacked with the DEM as the 2-channel input of the U-Net.

    :param resnet_output_size: Spatial size of each backbone feature map.
    :param fusion_output_size: Spatial size of the fused map; it is
        concatenated with ``dem`` in ``forward``, so the two must agree.
    """

    def __init__(self, resnet_output_size, fusion_output_size):
        super().__init__()
        self.resnet = ResNetFeatures(output_size=resnet_output_size)
        # The fusion layer is sized for six views of 2048 channels each.
        self.fusion_net = FusionNet(
            input_channels=6 * 2048,
            output_size=fusion_output_size,
        )
        self.unet = UNet_1(n_channels=2, n_classes=8)

    def forward(self, dem, rgbs):
        """Run the full pipeline.

        :param dem: DEM batch.
        :param rgbs: List of RGB image batches, one entry per view.
        :return: Output of the U-Net for the stacked DEM / fused-feature input.
        """
        per_view_features = [self.resnet(view) for view in rgbs]
        # Stack the views along the channel dimension before fusing them.
        stacked_features = torch.cat(per_view_features, dim=1)
        fused_map = self.fusion_net(stacked_features)

        unet_input = torch.cat((dem, fused_map), dim=1)
        return self.unet(unet_input)

In [78]:
class RGB_RasterTilesDataset(Dataset):
    """Dataset of matching DEM, SO and RGB raster tiles.

    Tiles are discovered from the DEM directory; the SO and RGB files of a
    tile are located through the shared ``<row>_<col>`` filename suffix.
    """

    # Number of RGB views stored per tile (files rgb0_... to rgb5_...).
    NUM_RGB_VIEWS = 6

    def __init__(self, dem_dir, so_dir, rgb_dir, transform=None):
        """
        Custom dataset to load DEM, SO, and RGB tiles.

        :param dem_dir: Directory where DEM tiles are stored.
        :param so_dir: Directory where SO tiles are stored.
        :param rgb_dir: Directory where RGB tiles are stored.
        :param transform: Optional transform to be applied on a sample.
        """
        self.dem_dir = dem_dir
        self.so_dir = so_dir
        self.rgb_dir = rgb_dir
        self.transform = transform

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> tile mapping (and therefore any seeded split) reproducible.
        # Each identifier is the 3rd and 4th '_'-separated field of the DEM
        # filename, so the second field still carries the file extension.
        self.tile_identifiers = [
            name.split('_')[2:4]
            for name in sorted(os.listdir(dem_dir))
            if 'dem_tile' in name
        ]

    def __len__(self):
        """Return the number of DEM tiles found."""
        return len(self.tile_identifiers)

    def __getitem__(self, idx):
        """Load one tile.

        :param idx: Integer index (a tensor index is converted with ``tolist``).
        :return: Dict with keys ``'DEM'``, ``'SO'`` and ``'RGB'`` (a list of
            ``NUM_RGB_VIEWS`` float32 arrays in channel-first layout, scaled
            by 1/255), passed through ``transform`` when one is set.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        tile_id = self.tile_identifiers[idx]
        suffix = f'{tile_id[0]}_{tile_id[1]}'
        dem_file = os.path.join(self.dem_dir, f'dem_tile_{suffix}')
        so_file = os.path.join(self.so_dir, f'so_tile_{suffix}')
        # RGB tiles follow the same naming convention with an rgb<k> prefix.
        rgb_files = [
            os.path.join(self.rgb_dir, f'rgb{k}_tile_{suffix}')
            for k in range(self.NUM_RGB_VIEWS)
        ]

        # Context managers close the file handles deterministically instead of
        # leaving them open until garbage collection.
        with Image.open(dem_file) as dem_image:
            dem_array = np.array(dem_image)
        with Image.open(so_file) as so_image:
            so_array = np.array(so_image)

        # Cast to float32 before scaling: dividing an integer array by 255
        # would yield float64, which does not match float32 model weights.
        rgb_arrays = [
            np.asarray(imageio.imread(path), dtype=np.float32).transpose(2, 0, 1) / 255
            for path in rgb_files
        ]

        sample = {'DEM': dem_array, 'SO': so_array, 'RGB': rgb_arrays}

        if self.transform:
            sample = self.transform(sample)

        return sample
    

class RGB_RasterTransform:
    """
    A custom transform class for raster data.

    Converts the numpy arrays of a sample dict (keys ``'DEM'``, ``'SO'``,
    ``'RGB'``) into tensors with explicit dtypes and normalises the DEM.
    """

    # Normalisation statistics applied to the DEM.
    DEM_MEAN = 318.90567
    DEM_STD = 16.467052

    def __init__(self):
        pass

    @staticmethod
    def _as_chw_tensor(array, dtype):
        """Convert a 2-D (H, W) or 3-D (H, W, C) array to a (C, H, W) tensor of ``dtype``."""
        tensor = torch.from_numpy(np.ascontiguousarray(array, dtype=dtype))
        if tensor.ndim == 2:
            return tensor.unsqueeze(0)
        if tensor.ndim == 3:
            return tensor.permute(2, 0, 1).contiguous()
        raise ValueError(f'expected a 2-D or 3-D array, got {tensor.ndim} dimensions')

    def __call__(self, sample):
        """Return ``{'DEM': float32 tensor, 'SO': int64 tensor, 'RGB': list of float32 tensors}``.

        :param sample: Dict with ``'DEM'`` and ``'SO'`` arrays and ``'RGB'``, a
            list of arrays that are already channel-first and scaled.
        """
        dem, so, rgb = sample['DEM'], sample['SO'], sample['RGB']

        # Explicit conversions are used instead of TF.to_tensor, which silently
        # rescales uint8 input by 1/255 and would corrupt integer class labels.
        dem = self._as_chw_tensor(dem, np.float32)
        dem = (dem - self.DEM_MEAN) / self.DEM_STD

        so = self._as_chw_tensor(so, np.int64)

        # The RGB arrays are already (C, H, W); only the type changes here, so
        # the model receives float32 tensors rather than raw numpy arrays.
        rgb_tensors = [
            torch.from_numpy(np.ascontiguousarray(image, dtype=np.float32))
            for image in rgb
        ]

        return {'DEM': dem, 'SO': so.squeeze(), 'RGB': rgb_tensors}

In [79]:
# Seed every random number generator used by the notebook (Python, NumPy,
# PyTorch) with the same value.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Ask cuDNN for deterministic kernels and turn off its autotuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [80]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(device)

# Root folder holding the dem/, so/ and rgb/ tile directories. Set the
# DEM2SO_DATA_ROOT environment variable to run the notebook on another machine;
# the default keeps the original location.
data_root = os.environ.get(
    'DEM2SO_DATA_ROOT',
    '/home/macula/SMATousi/Gullies/ground_truth/google_api/training_process/DEM2SO/dem2so/dem_with_rgb',
)
dem_dir = os.path.join(data_root, 'dem')
so_dir = os.path.join(data_root, 'so')
rgb_dir = os.path.join(data_root, 'rgb')


# Training hyper-parameters.
batch_size = 4
learning_rate = 0.0001
epochs = 10
number_of_workers = 1   # DataLoader worker processes
image_size = 128
val_percent = 0.1       # fraction of the dataset held out for validation

cuda:0


In [81]:
transform = RGB_RasterTransform()

dataset = RGB_RasterTilesDataset(dem_dir=dem_dir, so_dir=so_dir, rgb_dir=rgb_dir, transform=transform)

# Fail early with a readable message instead of an opaque sampler error.
if len(dataset) == 0:
    raise RuntimeError(f"No 'dem_tile' files found in {dem_dir}")

# Train / validation split
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
# A dedicated, seeded generator makes the split independent of how much the
# global RNG was consumed by previously executed cells.
split_generator = torch.Generator().manual_seed(0)
train, val = random_split(dataset, [n_train, n_val], generator=split_generator)

# DataLoader
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=number_of_workers, pin_memory=True)
# NOTE(review): drop_last=True discards the final partial validation batch, so
# up to batch_size - 1 validation samples are never evaluated — confirm intended.
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=number_of_workers, pin_memory=True, drop_last=True)

print("Data is loaded")

Data is loaded


In [82]:
# fusion_output_size has to equal the spatial size of the DEM tiles: inside
# RGB_DEM_to_SO.forward the fused map is concatenated with the DEM along the
# channel dimension.
model = RGB_DEM_to_SO(
    resnet_output_size=(8, 8),
    fusion_output_size=(128, 128),
)
model = model.to(device)

In [84]:
from torch.optim import Adam
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=learning_rate)

# True: train on every batch. False: quick test, stop each epoch after one batch.
arg_nottest = True

# Training loop
for epoch in range(epochs):
    model.train()
    train_metrics = {'Train/iou': 0}
    running_loss = 0.0
    batches_seen = 0

    for batch in tqdm(train_loader):
        dem = batch['DEM'].to(device)
        so = batch['SO'].to(device)
        rgbs = [batch['RGB'][k].to(device) for k in range(6)]

        # Forward pass
        outputs = model(dem, rgbs)
        loss = criterion(outputs, so)

        # The metric is bookkeeping only: compute it on a detached tensor so the
        # accumulator can never keep the autograd graph of a batch alive.
        with torch.no_grad():
            train_metrics['Train/iou'] += mIOU(so, outputs.detach())

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        batches_seen += 1

        if not arg_nottest:
            break

    # An empty loader would otherwise leave `loss` undefined and divide by zero.
    if batches_seen == 0:
        print(f"Epoch [{epoch+1}/{epochs}] - no batches were produced by train_loader")
        continue

    # Average over the batches actually processed (one batch in test mode,
    # len(train_loader) otherwise).
    for k in train_metrics:
        train_metrics[k] /= batches_seen

    # Report the mean loss of the epoch rather than the loss of its last batch.
    print(f"Epoch [{epoch+1}/{epochs}] - Loss: {running_loss / batches_seen}")
    print(train_metrics)

 14%|███████▉                                                  | 79/579 [00:51<05:26,  1.53it/s]


KeyboardInterrupt: 

In [40]:
# Pull a single batch from the training loader and move its DEM, SO and the
# six RGB views to the selected device.
batch = next(iter(train_loader))
dem, so = (batch[key].to(device) for key in ('DEM', 'SO'))
rgbs = [batch['RGB'][view].to(device) for view in range(6)]




In [45]:
# Run the model once on the `dem` / `rgbs` tensors currently in the kernel.
outputs = model(dem, rgbs)


torch.Size([1, 12288, 8, 8])


In [75]:
def count_parameters(model):
    """Return the number of scalar elements in the trainable parameters of ``model``.

    Parameters whose ``requires_grad`` flag is False are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [83]:
# Number of trainable parameters in the model (count_parameters skips
# parameters with requires_grad set to False).
count_parameters(model)

13403529