In [34]:
from torch.utils.data import DataLoader
import os
import torch

from tqdm import tqdm
import numpy as np

from auxilary.utils import *
from dataset import nucleiDataset, nucleiValDataset
from networkModules.modelUnet3p import UNet_3Plus
from auxilary.lossFunctions import *

In [26]:
config = readConfig('configs/config.sys')

In [35]:
trainPaths = config["trainDataset"]
valPaths = config["valDataset"]
train_dataset = nucleiDataset(trainPaths, config)
val_dataset = nucleiValDataset(valPaths, config)
trainLoader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
val_data = DataLoader(val_dataset,batch_size=1,num_workers=4)

In [42]:
model = UNet_3Plus(config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = weightedDiceLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=config["learning_rate"], weight_decay=1e-5)

In [43]:
epochs = 1
for epoch in range(epochs):
    for batch in tqdm(trainLoader):
        
        image, mask = batch
        gt = mask.squeeze().float()

        gt = gt.type(torch.float32)
        loss = criterion(model(image.to(device)), gt.to(device))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Loss: {loss.item()}")


        pass

    print(f"Epoch {epoch+1}/{epochs} completed.")

  0%|          | 2/33263 [00:07<28:20:19,  3.07s/it]

Loss: 0.5626470446586609
Loss: 0.6106583476066589
Loss: 0.5436936616897583


  0%|          | 6/33263 [00:07<6:11:48,  1.49it/s] 

Loss: 0.5795609951019287
Loss: 0.5498371124267578
Loss: 0.5808554887771606


  0%|          | 8/33263 [00:07<4:03:49,  2.27it/s]

Loss: 0.5382869243621826
Loss: 0.584365725517273
Loss: 0.5504485368728638


  0%|          | 12/33263 [00:08<2:07:45,  4.34it/s]

Loss: 0.5642147064208984
Loss: 0.6215934753417969
Loss: 0.5497806668281555


  0%|          | 14/33263 [00:08<1:41:10,  5.48it/s]

Loss: 0.5522156953811646
Loss: 0.5560946464538574
Loss: 0.5174639821052551


  0%|          | 18/33263 [00:08<1:11:54,  7.71it/s]

Loss: 0.583935022354126
Loss: 0.5627903938293457
Loss: 0.5296497344970703


  0%|          | 20/33263 [00:08<1:01:51,  8.96it/s]

Loss: 0.5612513422966003
Loss: 0.586361289024353
Loss: 0.560641348361969


  0%|          | 24/33263 [00:09<53:14, 10.40it/s]  

Loss: 0.5884459018707275
Loss: 0.5406578779220581
Loss: 0.5814355611801147


  0%|          | 26/33263 [00:09<51:18, 10.80it/s]

Loss: 0.5959405303001404
Loss: 0.539465069770813
Loss: 0.566691517829895


  0%|          | 28/33263 [00:09<48:40, 11.38it/s]

Loss: 0.5956835150718689
Loss: 0.5648857355117798


  0%|          | 32/33263 [00:09<50:32, 10.96it/s]

Loss: 0.5744208097457886
Loss: 0.557259202003479
Loss: 0.5626167058944702


  0%|          | 34/33263 [00:10<48:48, 11.35it/s]

Loss: 0.5252679586410522
Loss: 0.5473206639289856
Loss: 0.5506742000579834


  0%|          | 38/33263 [00:10<46:55, 11.80it/s]

Loss: 0.5872588753700256
Loss: 0.585180401802063
Loss: 0.5632878541946411


  0%|          | 40/33263 [00:10<46:09, 12.00it/s]

Loss: 0.5752347707748413
Loss: 0.5596056580543518
Loss: 0.5734161734580994


  0%|          | 44/33263 [00:10<46:34, 11.89it/s]

Loss: 0.5609489679336548
Loss: 0.5727795362472534
Loss: 0.5510673522949219


  0%|          | 46/33263 [00:11<46:14, 11.97it/s]

Loss: 0.5389384031295776
Loss: 0.5793780088424683
Loss: 0.5200209617614746


  0%|          | 50/33263 [00:11<45:12, 12.24it/s]

Loss: 0.5559569597244263
Loss: 0.5233997106552124
Loss: 0.5676417350769043


  0%|          | 52/33263 [00:11<44:05, 12.56it/s]

Loss: 0.5877354145050049
Loss: 0.6130759716033936
Loss: 0.5702688694000244


  0%|          | 54/33263 [00:11<43:48, 12.63it/s]

Loss: 0.5273369550704956
Loss: 0.6122372150421143


  0%|          | 56/33263 [00:11<48:53, 11.32it/s]

Loss: 0.5615043640136719
Loss: 0.5537708401679993


  0%|          | 60/33263 [00:12<52:27, 10.55it/s]

Loss: 0.5594611167907715
Loss: 0.5543913841247559
Loss: 0.5828352570533752


  0%|          | 62/33263 [00:12<51:58, 10.65it/s]

Loss: 0.5369651913642883
Loss: 0.5805069208145142
Loss: 0.5751917958259583


  0%|          | 66/33263 [00:12<49:18, 11.22it/s]

Loss: 0.5905096530914307
Loss: 0.5663304924964905
Loss: 0.5203932523727417


  0%|          | 68/33263 [00:13<48:40, 11.36it/s]

Loss: 0.5858075618743896
Loss: 0.5517416596412659
Loss: 0.5551397800445557


  0%|          | 72/33263 [00:13<45:20, 12.20it/s]

Loss: 0.5196487903594971
Loss: 0.5543729066848755
Loss: 0.5737084746360779


  0%|          | 74/33263 [00:13<45:52, 12.06it/s]

Loss: 0.5565739870071411
Loss: 0.5652649402618408
Loss: 0.5553887486457825


  0%|          | 78/33263 [00:13<47:27, 11.66it/s]

Loss: 0.5615660548210144
Loss: 0.5299070477485657
Loss: 0.5209241509437561


  0%|          | 80/33263 [00:14<48:49, 11.33it/s]

Loss: 0.5165789723396301
Loss: 0.5699918270111084
Loss: 0.5369954109191895


  0%|          | 82/33263 [00:14<47:19, 11.68it/s]

Loss: 0.5410093069076538
Loss: 0.5822775959968567


  0%|          | 84/33263 [00:14<54:33, 10.13it/s]

Loss: 0.5238839983940125
Loss: 0.5577692985534668


  0%|          | 88/33263 [00:14<51:05, 10.82it/s]

Loss: 0.5515720844268799
Loss: 0.5461410284042358
Loss: 0.5397001504898071


  0%|          | 90/33263 [00:15<49:43, 11.12it/s]

Loss: 0.5982769727706909
Loss: 0.578934907913208
Loss: 0.517635703086853


  0%|          | 94/33263 [00:15<46:48, 11.81it/s]

Loss: 0.603163480758667
Loss: 0.584244966506958
Loss: 0.5445002317428589


  0%|          | 96/33263 [00:15<46:26, 11.90it/s]

Loss: 0.5158274173736572
Loss: 0.5185099840164185
Loss: 0.5632999539375305


  0%|          | 100/33263 [00:15<45:45, 12.08it/s]

Loss: 0.5294826030731201
Loss: 0.5431348085403442
Loss: 0.5735999941825867


  0%|          | 100/33263 [00:15<1:27:54,  6.29it/s]


KeyboardInterrupt: 