<h1><center> Training a prompt model

# 1- Dataset

In [20]:

import numpy as np
import rasterio
import os 
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import os
import torch
import rasterio
import numpy as np
from torch.utils.data import Dataset,  DataLoader
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import torch
import os
import torch.nn.functional as F
import numpy as np
import tqdm
from transformers import SamModel


In [None]:
model_type = "vit_h"
path_to_checkpoint = 'checkpoint/sam_vit_h_4b8939.pth'

# Loading the model
sam_model = sam_model_registry[model_type](checkpoint=path_to_checkpoint)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
data_dir = "split"

train_img_dir = os.path.join(data_dir, 'train', 'img')
test_img_dir = os.path.join(data_dir, 'val', 'img')
train_msk_dir = os.path.join(data_dir, 'train', 'msk')
test_msk_dir = os.path.join(data_dir, 'val', 'msk')

In [62]:
def simple_equalization_8bit(im, percentiles=5):
    ''' im is a numpy array
        returns a numpy array
    '''
    out = np.zeros_like(im)
    # faire l'equalization par channel
    def equalize(im_channel):
        v_min, v_max = np.percentile(im_channel,percentiles),np.percentile(im_channel, 100 - percentiles)

        # Clip the image to the percentile values
        im_clipped = np.clip(im_channel, v_min, v_max)

        # Scale the image to the 0-255 range
        im_scaled = np.round((im_clipped - v_min) / (v_max - v_min))
        return im_scaled.astype(np.uint8)
    
    for channel in range(im.shape[0]):
        out[channel,:,:] = equalize(im[channel,:,:])
    
    return out

In [63]:
class S1S2Dataset(Dataset):
    def __init__(self, img_folder, mask_folder, processor, transform=None, target_transform=None,):
        self.img_folder = img_folder
        self.mask_folder = mask_folder
        self.transform = transform
        self.target_transform = target_transform
        self.img_filenames = [f for f in os.listdir(img_folder) if f.endswith('.tif') and "img" in f]
        self.processor = processor
        print(self.img_filenames)

    def __len__(self):
        return len(self.img_filenames)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        filename = self.img_filenames[idx]
        img_path = os.path.join(self.img_folder, filename)
        
        # image
        with rasterio.open(img_path) as src:
            image = src.read().astype(np.float32)[0:3, :, :]
            image = simple_equalization_8bit(image, percentiles=5) 
            image = torch.from_numpy(image) # shape (C, H, W)
            
            image = self.processor(image, return_tensors="pt")
            # remove batch dimension
            image = {k: v.squeeze(0) for k, v in image.items()}

        # masque
        mask_filename = filename.replace("img", "msk")
        mask_path = os.path.join(self.mask_folder, mask_filename)
        with rasterio.open(mask_path) as src:
            mask = src.read()[0].astype(np.float32)
            mask = torch.from_numpy(mask)
            mask = mask.unsqueeze(0)  # Ajouter une dimension de canal (C, H, W)

        # if self.transform:
        #     image = self.transform(image)
        #     mask = self.target_transform(mask)

        return image, mask

transform= transforms.Compose([
    transforms.Resize((1024, 1024)), 
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

transform_target = transforms.Compose([transforms.Resize((224, 224))])


In [64]:


from transformers import SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

dataset = S1S2Dataset(train_img_dir, train_msk_dir, transform=transform, processor=processor)

['sentinel12_s2_8_img_669.tif', 'sentinel12_s2_5_img_615.tif', 'sentinel12_s2_6_img_1747.tif', 'sentinel12_s2_12_img_1454.tif', 'sentinel12_s2_15_img_24.tif', 'sentinel12_s2_9_img_203.tif', 'sentinel12_s2_1_img_1545.tif', 'sentinel12_s2_1_img_497.tif', 'sentinel12_s2_16_img_298.tif', 'sentinel12_s2_15_img_1130.tif', 'sentinel12_s2_9_img_565.tif', 'sentinel12_s2_12_img_1332.tif', 'sentinel12_s2_6_img_1021.tif', 'sentinel12_s2_6_img_458.tif', 'sentinel12_s2_5_img_173.tif', 'sentinel12_s2_12_img_1326.tif', 'sentinel12_s2_6_img_1035.tif', 'sentinel12_s2_5_img_167.tif', 'sentinel12_s2_12_img_768.tif', 'sentinel12_s2_1_img_1237.tif', 'sentinel12_s2_1_img_483.tif', 'sentinel12_s2_15_img_1124.tif', 'sentinel12_s2_9_img_571.tif', 'sentinel12_s2_15_img_30.tif', 'sentinel12_s2_9_img_217.tif', 'sentinel12_s2_15_img_1642.tif', 'sentinel12_s2_1_img_1551.tif', 'sentinel12_s2_5_img_601.tif', 'sentinel12_s2_6_img_1753.tif', 'sentinel12_s2_12_img_1440.tif', 'sentinel12_s2_1_img_1579.tif', 'sentinel12_s2

In [65]:

train_loader = DataLoader(dataset, batch_size=4, shuffle=True)

In [66]:
print("train_loader:", len(train_loader))
print(next(iter(train_loader))[0]["pixel_values"].shape)

train_loader: 3290
torch.Size([4, 3, 1024, 1024])


# 2- SAM

In [67]:
# Fine-tuning the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [68]:
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
# make sure we only compute gradients for mask decoder
for name, param in model.named_parameters():
  if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
    param.requires_grad_(False)

In [69]:
model = model.to(device)

# Define loss function and optimizer
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.mask_decoder.parameters(), lr=0.001, momentum=0.9)

model.train()
for epoch in range(1):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in tqdm.tqdm(enumerate(train_loader, 0)):
        inputs, labels = data
        labels = labels.float().to(device)

        outputs = model(pixel_values=inputs["pixel_values"].to(device), multimask_output=False)
        
        predicted_masks = outputs.pred_masks.squeeze(1)
        print(' predicted_masks',predicted_masks.shape)
        print('labels', labels.shape)
        
        loss = criterion(predicted_masks, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        print(loss.item())
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

# torch.save(model.state_dict(), PATH)

print('Finished Training')

0it [00:00, ?it/s]

 predicted_masks torch.Size([4, 1, 256, 256])
labels torch.Size([4, 1, 256, 256])


1it [00:16, 16.63s/it]

1.2275418043136597
 predicted_masks torch.Size([4, 1, 256, 256])
labels torch.Size([4, 1, 256, 256])


2it [00:32, 15.90s/it]

3.6721062660217285
 predicted_masks torch.Size([4, 1, 256, 256])
labels torch.Size([4, 1, 256, 256])


3it [00:48, 16.08s/it]

137.8570098876953
 predicted_masks torch.Size([4, 1, 256, 256])
labels torch.Size([4, 1, 256, 256])


4it [01:01, 14.76s/it]

5.859615325927734
 predicted_masks torch.Size([4, 1, 256, 256])
labels torch.Size([4, 1, 256, 256])


5it [01:13, 14.00s/it]

0.8498289585113525
 predicted_masks torch.Size([4, 1, 256, 256])
labels torch.Size([4, 1, 256, 256])


6it [01:26, 13.59s/it]

1.8026902675628662
 predicted_masks torch.Size([4, 1, 256, 256])
labels torch.Size([4, 1, 256, 256])


7it [01:39, 13.49s/it]

0.3483894467353821
 predicted_masks torch.Size([4, 1, 256, 256])
labels torch.Size([4, 1, 256, 256])


8it [01:55, 14.21s/it]

0.40903347730636597
 predicted_masks torch.Size([4, 1, 256, 256])
labels torch.Size([4, 1, 256, 256])


9it [02:08, 13.82s/it]

0.6625363826751709
 predicted_masks torch.Size([4, 1, 256, 256])
labels torch.Size([4, 1, 256, 256])


10it [02:47, 21.66s/it]

0.38073602318763733
 predicted_masks torch.Size([4, 1, 256, 256])
labels torch.Size([4, 1, 256, 256])


11it [03:00, 19.06s/it]

1.006957769393921
 predicted_masks torch.Size([4, 1, 256, 256])
labels torch.Size([4, 1, 256, 256])


12it [03:13, 17.14s/it]

0.09925845265388489
 predicted_masks torch.Size([4, 1, 256, 256])
labels torch.Size([4, 1, 256, 256])


13it [03:26, 15.77s/it]

0.538582444190979
 predicted_masks torch.Size([4, 1, 256, 256])
labels torch.Size([4, 1, 256, 256])


14it [03:39, 14.91s/it]

0.055601876229047775
 predicted_masks torch.Size([4, 1, 256, 256])
labels torch.Size([4, 1, 256, 256])


15it [03:55, 15.34s/it]

0.9860178232192993
 predicted_masks torch.Size([4, 1, 256, 256])
labels torch.Size([4, 1, 256, 256])


16it [04:12, 15.79s/it]

0.8145720362663269
 predicted_masks torch.Size([4, 1, 256, 256])
labels torch.Size([4, 1, 256, 256])


17it [04:35, 18.02s/it]

0.5232590436935425
 predicted_masks torch.Size([4, 1, 256, 256])
labels torch.Size([4, 1, 256, 256])


18it [04:55, 18.51s/it]

2.069855213165283


18it [05:01, 16.74s/it]


KeyboardInterrupt: 