In [1]:
from transformers import UperNetForSemanticSegmentation, AutoImageProcessor
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import os
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from evaluate import load

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
img_dir = '/home/a-ploskin/repos/TerraLabel/data/task_0/data'
masks_dir = '/home/a-ploskin/repos/TerraLabel/data/masks'

In [3]:
class SemanticSegmentationDataset(Dataset):
    """Image (semantic) segmentation dataset."""

    def __init__(self, image_processor, train, img_dir=img_dir, masks_dir=masks_dir):
        

        self.img_pathes = [ f'{img_dir}/{item}' for item in sorted(os.listdir(img_dir)) ]
        self.img_pathes = [item for item in self.img_pathes if item[-3:] == 'jpg']
        self.mask_pathes = [ f'{masks_dir}/{item}' for item in sorted(os.listdir(masks_dir)) ]

        if train == False: 
            self.img_pathes = self.img_pathes[:20]
            self.mask_pathes = self.mask_pathes[:20]
        
        self.image_processor = image_processor

        assert len(self.img_pathes) == len(self.mask_pathes), "There must be as many images as there are segmentation maps"

    def __len__(self):
        return len(self.img_pathes)

    def __getitem__(self, idx):

        mask = np.load(self.mask_pathes[idx])
        image = np.array( Image.open(self.img_pathes[idx])) 

        encoded_inputs = self.image_processor(image, mask, return_tensors="pt")

        for k,v in encoded_inputs.items():
            encoded_inputs[k].squeeze_() # remove batch dimension

        return encoded_inputs
     

In [4]:
model = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-swin-large")
processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-swin-large")

model.to('cuda')

UperNetForSemanticSegmentation(
  (backbone): SwinBackbone(
    (embeddings): SwinEmbeddings(
      (patch_embeddings): SwinPatchEmbeddings(
        (projection): Conv2d(3, 192, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): SwinEncoder(
      (layers): ModuleList(
        (0): SwinStage(
          (blocks): ModuleList(
            (0): SwinLayer(
              (layernorm_before): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
              (attention): SwinAttention(
                (self): SwinSelfAttention(
                  (query): Linear(in_features=192, out_features=192, bias=True)
                  (key): Linear(in_features=192, out_features=192, bias=True)
                  (value): Linear(in_features=192, out_features=192, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
                )
                (output): Swin

In [5]:
train_ds = SemanticSegmentationDataset(processor, train=True)
eval_ds = SemanticSegmentationDataset(processor, train=False)

In [6]:
train_dataloader = DataLoader(train_ds, batch_size=2, shuffle=True, drop_last=True)
eval_dataloader = DataLoader(eval_ds, batch_size=2, shuffle=True, drop_last=True)

In [7]:
log_path = './logs'
writer = SummaryWriter(log_path)
metric = load("mean_iou")

In [8]:
import torch
from tqdm import tqdm

optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006)
device = torch.device("cuda:0")
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

train_loss_iter = []
train_loss_epoch = []
eval_iou = []
eval_acc = []
eval_loss = []
model.train()

for epoch in tqdm(range(200)):  # loop over the dataset multiple times
    print("Epoch:", epoch)
    curr_epoch_loss = []
    curr_epoch_eval_loss = []
    for idx, batch in enumerate(train_dataloader):
        # print(idx)
        
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        optimizer.zero_grad()
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits

        loss.backward()
        optimizer.step()
        #print()
        train_loss_iter.append(loss.item())
        curr_epoch_loss.append(loss.item())
        writer.add_scalar("Train/loss_step", train_loss_iter[-1], idx + epoch * len(train_dataloader))
        writer.add_scalar("Train/epoch", epoch + 1, idx + epoch * len(train_dataloader))
    train_loss_epoch.append(sum(curr_epoch_loss) / len(curr_epoch_loss))   
    writer.add_scalar("Train/loss_epoch", train_loss_epoch[-1], epoch + 1)

    with torch.no_grad():
        for batch in eval_dataloader:
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss, logits = outputs.loss, outputs.logits
            curr_epoch_eval_loss.append(outputs.loss.item())
            upsampled_logits = torch.nn.functional.interpolate(logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
            predicted = upsampled_logits.argmax(dim=1)
            metric.add_batch(predictions=predicted.detach().cpu().numpy(), references=labels.detach().cpu().numpy())
            
        metrics = metric.compute(num_labels=6, 
                                       ignore_index=255,
                                       reduce_labels=False, # we've already reduced the labels before)
            )
        eval_iou.append( metrics["mean_iou"])
        eval_acc.append(metrics["mean_accuracy"])
        eval_loss.append(sum(curr_epoch_eval_loss) / len(eval_dataloader))
        
        writer.add_scalar("Eval/loss",eval_loss[-1], epoch + 1)
        writer.add_scalar("Eval/Accuracy", metrics["mean_accuracy"], epoch + 1)
        writer.add_scalar("Eval/IoU", metrics["mean_iou"], epoch + 1)
        
        print("Mean_iou:", metrics["mean_iou"])
        print("Loss:", train_loss_epoch[-1])
        print("Mean accuracy:", metrics["mean_accuracy"])

  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 0


  0%|          | 1/200 [00:56<3:08:43, 56.90s/it]

Mean_iou: 0.33513295392727
Loss: 5.279483486966389
Mean accuracy: 0.48540135074217655
Epoch: 1


  1%|          | 2/200 [01:53<3:07:02, 56.68s/it]

Mean_iou: 0.4415849665942497
Loss: 2.995144341050125
Mean accuracy: 0.6112441620118598
Epoch: 2


  2%|▏         | 3/200 [02:50<3:06:07, 56.69s/it]

Mean_iou: 0.4333888309714377
Loss: 2.2524409977401176
Mean accuracy: 0.585202091765081
Epoch: 3


  2%|▏         | 4/200 [03:46<3:05:10, 56.69s/it]

Mean_iou: 0.6024412959191437
Loss: 1.8588783573813554
Mean accuracy: 0.6964514176383734
Epoch: 4


  2%|▎         | 5/200 [04:43<3:04:22, 56.73s/it]

Mean_iou: 0.5743818836211058
Loss: 1.1203867567748558
Mean accuracy: 0.709665122179954
Epoch: 5


  3%|▎         | 6/200 [05:40<3:03:26, 56.74s/it]

Mean_iou: 0.6477182858332481
Loss: 1.6687870621681213
Mean accuracy: 0.7176897505675166
Epoch: 6


  4%|▎         | 7/200 [06:37<3:02:29, 56.73s/it]

Mean_iou: 0.6725019149421166
Loss: 1.6355517797353791
Mean accuracy: 0.7420202116428954
Epoch: 7


  4%|▍         | 8/200 [07:33<3:01:34, 56.74s/it]

Mean_iou: 0.6221001300786968
Loss: 1.0238804132109736
Mean accuracy: 0.7736140562698998
Epoch: 8


  4%|▍         | 9/200 [08:30<3:00:34, 56.73s/it]

Mean_iou: 0.6881104968943639
Loss: 1.2509878045175133
Mean accuracy: 0.7808637938543755
Epoch: 9


  5%|▌         | 10/200 [09:27<2:59:43, 56.76s/it]

Mean_iou: 0.7302816116021508
Loss: 0.9044572741883558
Mean accuracy: 0.8182789993312477
Epoch: 10


  6%|▌         | 11/200 [10:24<2:58:53, 56.79s/it]

Mean_iou: 0.7111049114061491
Loss: 0.6396608217278632
Mean accuracy: 0.7854463097357357
Epoch: 11


  6%|▌         | 12/200 [11:20<2:57:52, 56.77s/it]

Mean_iou: 0.6693664040105247
Loss: 0.8014264964475865
Mean accuracy: 0.730130046860713
Epoch: 12


  6%|▋         | 13/200 [12:17<2:56:56, 56.77s/it]

Mean_iou: 0.7831486216011495
Loss: 0.5721659438639153
Mean accuracy: 0.8544375130137571
Epoch: 13


  7%|▋         | 14/200 [13:14<2:56:04, 56.80s/it]

Mean_iou: 0.7485736663718779
Loss: 0.5269277524203062
Mean accuracy: 0.8495763190806382
Epoch: 14


  8%|▊         | 15/200 [14:11<2:55:13, 56.83s/it]

Mean_iou: 0.7926909418979502
Loss: 0.7534447294182893
Mean accuracy: 0.865568251787816
Epoch: 15


  8%|▊         | 16/200 [15:08<2:54:16, 56.83s/it]

Mean_iou: 0.8025793157789013
Loss: 0.6807066912694675
Mean accuracy: 0.8616639047594945
Epoch: 16


  8%|▊         | 16/200 [15:40<3:00:20, 58.81s/it]


KeyboardInterrupt: 