In [1]:
import os
import sys
from datetime import datetime
from pathlib import Path

from evaluate import load
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import UperNetForSemanticSegmentation, AutoImageProcessor
from transformers import set_seed

# Make the parent of the current working directory importable so the local
# `training` package below can be found.
sys.path.append('..')

from training.dataset import SemanticSegmentationDataset
from training.trainer import SegmenterModeltrainer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# The model weights and the image preprocessing config come from the same Hub
# checkpoint; keeping the id in one place guarantees the two stay in sync.
CHECKPOINT = "openmmlab/upernet-swin-large"

# NOTE(review): the model is loaded with the checkpoint's own label set. If the masks
# use a different number of classes, num_labels/id2label (and
# ignore_mismatched_sizes=True) probably need to be passed here — confirm against the dataset.
processor = AutoImageProcessor.from_pretrained(CHECKPOINT)
model = UperNetForSemanticSegmentation.from_pretrained(CHECKPOINT)

In [3]:
# Data locations. The default assumes the notebook runs from a sub-directory of the
# repo root (the same assumption `sys.path.append('..')` makes). Set
# TERRALABEL_DATA_DIR to point somewhere else without editing the notebook.
DATA_DIR = Path(os.environ.get('TERRALABEL_DATA_DIR', '../data'))
img_dir = str(DATA_DIR / 'task_0' / 'data')
masks_dir = str(DATA_DIR / 'masks')

# Seed the python/numpy/torch RNGs before anything stochastic happens, so runs are
# reproducible. This covers the train/eval split (if it is random) and DataLoader shuffling.
SEED = 42
set_seed(SEED)

train_ds, eval_ds = SemanticSegmentationDataset.get_train_and_eval_datasets(
    processor, img_dir, masks_dir
)

In [4]:
BATCH_SIZE = 2

# Training: reshuffle every epoch. drop_last avoids a trailing under-sized batch.
# NOTE(review): presumably this protects BatchNorm layers in train mode — confirm.
train_dataloader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# Evaluation: use a fixed order so the evaluated subset is identical every epoch.
# With shuffle=True and drop_last=True, a different random sample was skipped each
# epoch whenever len(eval_ds) is not a multiple of BATCH_SIZE. Per-epoch metrics were
# then not measured on the same data.
# NOTE(review): drop_last=True still skips up to BATCH_SIZE-1 samples. Switch it to
# False if the trainer evaluates with model.eval().
eval_dataloader = DataLoader(eval_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

In [5]:
# Give every run its own TensorBoard directory under ./logs.
# A fixed './logs' makes each SummaryWriter add new event files next to older ones.
# TensorBoard then merges them into a single run with overlapping steps.
# `tensorboard --logdir ./logs` still works and now lists each run separately.
log_path = f'./logs/run_{datetime.now():%Y-%m-%d_%H-%M-%S}'
writer = SummaryWriter(log_path)

# Segmentation metric from the `evaluate` hub (mean IoU / mean accuracy per class).
metric = load("mean_iou")

In [6]:
DEVICE = 'cuda:0'
N_EPOCHS = 50

trainer = SegmenterModeltrainer(
    model=model,
    device=DEVICE,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    metric=metric,
    writer=writer
)

# str(datetime.now()) contains a space, colons and microseconds
# (e.g. '2024-05-01 12:34:56.789012'). Colons are invalid in Windows file names and
# spaces need quoting in shells, so the timestamp is formatted explicitly.
# The 'uppernet_' prefix is kept unchanged so existing globs still match.
file_path = f'./models/uppernet_{datetime.now():%Y-%m-%d_%H-%M-%S}'

# NOTE(review): how train() uses file_path is not visible here. Create ./models up
# front in case it saves with a call that does not create parent directories.
Path(file_path).parent.mkdir(parents=True, exist_ok=True)

segmenter = trainer.train(n_epochs=N_EPOCHS, file_path=file_path)

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch: 0


  2%|▏         | 1/50 [00:45<37:27, 45.88s/it]

Mean_iou: 0.27322078836219077
Loss: 5.1819842766071185
Mean accuracy: 0.41474975745371623
Epoch: 1


  4%|▍         | 2/50 [01:31<36:39, 45.82s/it]

Mean_iou: 0.27191924334378276
Loss: 3.6928261908991584
Mean accuracy: 0.4486458452335427
Epoch: 2


  6%|▌         | 3/50 [02:17<35:49, 45.74s/it]

Mean_iou: 0.33153283871283806
Loss: 2.1294291101653
Mean accuracy: 0.47619692874476854
Epoch: 3


  8%|▊         | 4/50 [03:01<34:32, 45.04s/it]

Mean_iou: 0.3309435363966511
Loss: 2.3312939590421218
Mean accuracy: 0.5256594530411991
Epoch: 4


 10%|█         | 5/50 [03:46<33:55, 45.23s/it]

Mean_iou: 0.3568862975138596
Loss: 1.4211649771394401
Mean accuracy: 0.4721316409304654
Epoch: 5


 12%|█▏        | 6/50 [04:30<32:54, 44.87s/it]

Mean_iou: 0.31773534588291935
Loss: 1.8184803066582516
Mean accuracy: 0.45798728381464565
Epoch: 6


 14%|█▍        | 7/50 [05:16<32:20, 45.14s/it]

Mean_iou: 0.3805451261350566
Loss: 1.2433922388430299
Mean accuracy: 0.5084797263135931
Epoch: 7


 16%|█▌        | 8/50 [06:02<31:43, 45.32s/it]

Mean_iou: 0.40039588386148167
Loss: 0.9446314963801153
Mean accuracy: 0.5652796589107263
Epoch: 8


 18%|█▊        | 9/50 [06:46<30:44, 44.98s/it]

Mean_iou: 0.3497703959458254
Loss: 1.5184520894083484
Mean accuracy: 0.505910311040904
Epoch: 9


 20%|██        | 10/50 [07:30<29:50, 44.75s/it]

Mean_iou: 0.3638247487753976
Loss: 1.018093732924297
Mean accuracy: 0.5276095534273362
Epoch: 10


 22%|██▏       | 11/50 [08:16<29:16, 45.05s/it]

Mean_iou: 0.37455566950183466
Loss: 0.9428259601880764
Mean accuracy: 0.5069465683492314
Epoch: 11


 24%|██▍       | 12/50 [09:02<28:37, 45.20s/it]

Mean_iou: 0.3875971294513143
Loss: 0.8056790337994181
Mean accuracy: 0.5291314624535364
Epoch: 12


 26%|██▌       | 13/50 [09:46<27:39, 44.84s/it]

Mean_iou: 0.39565134537575836
Loss: 0.8219754681761923
Mean accuracy: 0.5635943485543339
Epoch: 13


 28%|██▊       | 14/50 [10:31<27:03, 45.09s/it]

Mean_iou: 0.4499673841383087
Loss: 0.6037019557993988
Mean accuracy: 0.5585074368809994
Epoch: 14


  acc = total_area_intersect / total_area_label
 30%|███       | 15/50 [11:17<26:23, 45.23s/it]

Mean_iou: 0.4150918534493724
Loss: 0.4420018068932254
Mean accuracy: 0.6378675926385516
Epoch: 15


 32%|███▏      | 16/50 [12:01<25:25, 44.87s/it]

Mean_iou: 0.4515095502306688
Loss: 0.6426604918107904
Mean accuracy: 0.5857363667907763
Epoch: 16


 34%|███▍      | 17/50 [12:45<24:32, 44.61s/it]

Mean_iou: 0.46287300931738895
Loss: 0.7107241027827921
Mean accuracy: 0.5914294394801
Epoch: 17


 36%|███▌      | 18/50 [13:31<23:56, 44.90s/it]

Mean_iou: 0.43641163964250285
Loss: 0.3521282860431178
Mean accuracy: 0.5938006362320827
Epoch: 18


 38%|███▊      | 19/50 [14:16<23:18, 45.13s/it]

Mean_iou: 0.42385489877701027
Loss: 0.3403227046645921
Mean accuracy: 0.5887282439793468
Epoch: 19


 40%|████      | 20/50 [15:00<22:23, 44.79s/it]

Mean_iou: 0.42870785914385756
Loss: 0.548453553997237
Mean accuracy: 0.5809486368937423
Epoch: 20


 42%|████▏     | 21/50 [15:44<21:32, 44.57s/it]

Mean_iou: 0.39459977928875706
Loss: 0.604637203545406
Mean accuracy: 0.5696481832239937
Epoch: 21


 44%|████▍     | 22/50 [16:28<20:43, 44.43s/it]

Mean_iou: 0.4199279414736803
Loss: 0.7241327611022982
Mean accuracy: 0.5578769574951046
Epoch: 22


 46%|████▌     | 23/50 [17:12<19:55, 44.29s/it]

Mean_iou: 0.40718292416392754
Loss: 0.6318770490329841
Mean accuracy: 0.5811244714966612
Epoch: 23


 48%|████▊     | 24/50 [17:56<19:09, 44.20s/it]

Mean_iou: 0.4636881572807497
Loss: 0.5717968177692644
Mean accuracy: 0.5788161622887297
Epoch: 24


 50%|█████     | 25/50 [18:40<18:23, 44.15s/it]

Mean_iou: 0.41839376873455564
Loss: 0.5224194187542488
Mean accuracy: 0.5916923161523467
Epoch: 25


 52%|█████▏    | 26/50 [19:24<17:38, 44.10s/it]

Mean_iou: 0.42954039616625495
Loss: 0.6261489265437784
Mean accuracy: 0.6072630752789
Epoch: 26


 54%|█████▍    | 27/50 [20:08<16:54, 44.09s/it]

Mean_iou: 0.4687965970524253
Loss: 0.49618032060820483
Mean accuracy: 0.5951042054306086
Epoch: 27


 56%|█████▌    | 28/50 [20:52<16:10, 44.09s/it]

Mean_iou: 0.4432155354259602
Loss: 0.37340166728044377
Mean accuracy: 0.6225552998854259
Epoch: 28


 58%|█████▊    | 29/50 [21:36<15:25, 44.07s/it]

Mean_iou: 0.48295132094035065
Loss: 0.36240216447361584
Mean accuracy: 0.6396089057787918
Epoch: 29


 60%|██████    | 30/50 [22:22<14:52, 44.64s/it]

Mean_iou: 0.4851372111799377
Loss: 0.2578018145828411
Mean accuracy: 0.5938211400487322
Epoch: 30


 62%|██████▏   | 31/50 [23:07<14:05, 44.48s/it]

Mean_iou: 0.49064112178242564
Loss: 0.35404270126259535
Mean accuracy: 0.6056897388376058
Epoch: 31


 64%|██████▍   | 32/50 [23:51<13:19, 44.39s/it]

Mean_iou: 0.48576263795512936
Loss: 0.36151110653861845
Mean accuracy: 0.6115730877462415
Epoch: 32


 66%|██████▌   | 33/50 [24:35<12:35, 44.42s/it]

Mean_iou: 0.501867277320294
Loss: 0.35758901120902137
Mean accuracy: 0.6229552883344921
Epoch: 33


 68%|██████▊   | 34/50 [25:19<11:49, 44.37s/it]

Mean_iou: 0.478798756034211
Loss: 0.3122031817148472
Mean accuracy: 0.6303220607340506
Epoch: 34


  iou = total_area_intersect / total_area_union
 70%|███████   | 35/50 [26:03<11:03, 44.24s/it]

Mean_iou: 0.5663211761542736
Loss: 0.2678176937945958
Mean accuracy: 0.7179822934944762
Epoch: 35


 72%|███████▏  | 36/50 [26:47<10:18, 44.16s/it]

Mean_iou: 0.47284697882966825
Loss: 0.30064019804884645
Mean accuracy: 0.5963586706690313
Epoch: 36


 74%|███████▍  | 37/50 [27:33<09:38, 44.53s/it]

Mean_iou: 0.456040108473339
Loss: 0.2514503228253332
Mean accuracy: 0.6274904003034659
Epoch: 37


 76%|███████▌  | 38/50 [28:19<08:58, 44.90s/it]

Mean_iou: 0.45410104545845775
Loss: 0.24165597635096517
Mean accuracy: 0.5840838313176525
Epoch: 38


 78%|███████▊  | 39/50 [29:04<08:16, 45.17s/it]

Mean_iou: 0.48215161039608817
Loss: 0.22660645145665984
Mean accuracy: 0.6425382638236702
Epoch: 39


 80%|████████  | 40/50 [29:48<07:28, 44.85s/it]

Mean_iou: 0.4846306482376764
Loss: 0.23717511901310806
Mean accuracy: 0.6505310111324092
Epoch: 40


 82%|████████▏ | 41/50 [30:32<06:41, 44.59s/it]

Mean_iou: 0.46329104102004254
Loss: 0.23541904658336063
Mean accuracy: 0.5998296843635317
Epoch: 41


 84%|████████▍ | 42/50 [31:17<05:55, 44.43s/it]

Mean_iou: 0.46598740747257955
Loss: 0.3920380601595188
Mean accuracy: 0.6165838470722623
Epoch: 42


 86%|████████▌ | 43/50 [32:01<05:10, 44.36s/it]

Mean_iou: 0.42764043398395746
Loss: 0.28343339062071055
Mean accuracy: 0.5540723430858673
Epoch: 43


 88%|████████▊ | 44/50 [32:45<04:25, 44.27s/it]

Mean_iou: 0.44123676594210187
Loss: 0.3057170481710085
Mean accuracy: 0.5451933139257997
Epoch: 44


 90%|█████████ | 45/50 [33:29<03:41, 44.27s/it]

Mean_iou: 0.4806279659230857
Loss: 0.3123343954794109
Mean accuracy: 0.6440610953347438
Epoch: 45


 92%|█████████▏| 46/50 [34:13<02:57, 44.26s/it]

Mean_iou: 0.5012958597505803
Loss: 0.3090475155843486
Mean accuracy: 0.6621310340682703
Epoch: 46


 94%|█████████▍| 47/50 [34:57<02:12, 44.23s/it]

Mean_iou: 0.49020405845792303
Loss: 0.2797164999847782
Mean accuracy: 0.6433179232925288
Epoch: 47


 96%|█████████▌| 48/50 [35:42<01:28, 44.19s/it]

Mean_iou: 0.4818657997571663
Loss: 0.36560983801710195
Mean accuracy: 0.6391502020632417
Epoch: 48


 98%|█████████▊| 49/50 [36:30<00:45, 45.54s/it]

Mean_iou: 0.5116608894058454
Loss: 0.23481235833003603
Mean accuracy: 0.6385019885555122
Epoch: 49


100%|██████████| 50/50 [37:20<00:00, 44.80s/it]

Mean_iou: 0.5177258549714248
Loss: 0.2642797739084425
Mean accuracy: 0.6196782735269751



