In [1]:
import numpy as np
import cv2
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import Adam
import ternausnet
import ternausnet.models
from torchvision import transforms
import os
from PIL import Image
from torch.utils.data import Dataset
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class SegmentationDataset(Dataset):
    """Dataset yielding every file of a directory as an RGB image plus its filename.

    Each item is a tuple ``(image, filename)``. Returning the filename lets the
    caller save a result under the same name as the input it came from.

    Args:
        img_dir: Directory that contains the input images.
        transform: Optional callable applied to the RGB ``PIL.Image`` before it
            is returned. When ``None`` the PIL image is returned unchanged.
    """

    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order and also lists
        # sub-directories; sort for reproducibility and keep regular files only.
        self.img_names = sorted(
            name for name in os.listdir(img_dir)
            if os.path.isfile(os.path.join(img_dir, name))
        )

    def __len__(self):
        """Return the number of files found in ``img_dir``."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Load the ``idx``-th file as an RGB image.

        Returns:
            ``(image, filename)`` where ``image`` is the output of ``transform``
            (or the RGB ``PIL.Image`` itself when no transform was given).
        """
        img_name = self.img_names[idx]
        img_path = os.path.join(self.img_dir, img_name)

        # The context manager closes the file handle as soon as the pixels are
        # copied. convert("RGB") accepts any source mode (L, P, RGBA, 1, I, ...),
        # whereas Image.merge("RGB", ...) only accepts single-band "L" inputs.
        with Image.open(img_path) as raw:
            image = raw.convert("RGB")

        # transform is optional; calling None would raise a TypeError.
        if self.transform is not None:
            image = self.transform(image)

        return image, img_name

In [3]:
# Stage-1 input pipeline: read every image in the input folder and convert it
# to a tensor with torchvision's ToTensor.
transform = transforms.Compose([
    transforms.ToTensor(),
])
dataset = SegmentationDataset(img_dir='D:/study/dl/hackathon/github directory/input_image',transform=transform)
# NOTE(review): no resizing is applied, so a batch size > 1 presumably requires
# all images to share one resolution - confirm before raising this.
batch_S = 1
# Inference only: each result is saved under its own filename, so shuffling
# adds nothing and makes the processing order non-reproducible.
dataloader = DataLoader(dataset, batch_size=batch_S, shuffle=False)

In [4]:
class DiceLoss(nn.Module):
    """Soft Dice loss: ``1 - (2 * sum(x * t) + smooth) / (sum(x) + sum(t) + smooth)``.

    Both tensors are flattened and treated as one long vector, so the score is
    computed over the whole batch at once.

    Args:
        smooth: Constant added to numerator and denominator; it keeps the ratio
            defined when both tensors sum to zero. Defaults to ``1.0``.
    """

    def __init__(self, smooth=1.):
        super().__init__()
        self.smooth = smooth

    def forward(self, input, target):
        """Return the scalar Dice loss between ``input`` and ``target``.

        Args:
            input: Predicted values; must have the same number of elements as
                ``target``.
            target: Reference values.
        """
        # reshape instead of view: view raises on non-contiguous tensors
        # (e.g. after transpose/permute/slicing), reshape copies when needed.
        iflat = input.reshape(-1)
        tflat = target.reshape(-1)
        intersection = (iflat * tflat).sum()

        return 1 - ((2. * intersection + self.smooth) /
                    (iflat.sum() + tflat.sum() + self.smooth))

class IoULoss(nn.Module):
    """Soft IoU (Jaccard) loss: ``1 - (I + smooth) / (U + smooth)``.

    ``I = sum(input * target)`` and ``U = sum(input + target) - I``, both taken
    over every element of the tensors.

    Args:
        smooth: Constant added to numerator and denominator; it keeps the ratio
            defined when the union is zero. Defaults to ``1.0``.
    """

    def __init__(self, smooth=1.):
        super().__init__()
        self.smooth = smooth

    def forward(self, input, target):
        """Return the scalar IoU loss between ``input`` and ``target``.

        Args:
            input: Predicted values, broadcastable against ``target``.
            target: Reference values.
        """
        intersection = (input * target).sum()
        # |A u B| = |A| + |B| - |A n B|
        union = (input + target).sum() - intersection

        return 1 - ((intersection + self.smooth) / (union + self.smooth))

In [5]:
# Stage-1 model setup: restore the trained UNet11 from its checkpoint.
dataloader_test =  dataloader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code - only load checkpoints from a trusted source.
checkpoint = torch.load("D:/study/dl/hackathon/making dataset/models/model_try2_14.pth", map_location=device)
# pretrained=False: load_state_dict (strict by default) replaces every weight
# with the checkpoint's values, so fetching the ImageNet encoder weights first
# would only cost a download and then be discarded.
model = ternausnet.models.UNet11(pretrained=False)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)

# Not used by the inference cell below; kept so that any other cell relying on
# these names keeps working.
criterion = DiceLoss()
optimizer = Adam(model.parameters(), lr=0.001)



In [6]:
# Stage-1 inference: run the model over every input image and save the
# sigmoid probability map as an 8-bit image under the input's filename.
stage1_output_dir = "D:/study/dl/hackathon/github directory/first_stage_outputs"
# cv2.imwrite does not create missing folders; it just returns False.
os.makedirs(stage1_output_dir, exist_ok=True)

model.eval()  # Set the model to evaluation mode
with torch.no_grad():  # Don't compute gradients
    for images, img_names in dataloader:
        images = images.to(device)
        # Move the whole batch to the CPU once instead of once per item.
        outputs = torch.sigmoid(model(images)).cpu()
        # Pair each prediction with its filename; unlike range(batch_S) this
        # stays correct when the last batch is smaller than the batch size.
        for prob_map, img_name in zip(outputs, img_names):
            # Sigmoid output lies in [0, 1]; scale to 0..255 for an 8-bit image.
            mask = (prob_map.squeeze().numpy() * 255).astype('uint8')
            out_path = f"{stage1_output_dir}/{img_name}"
            if not cv2.imwrite(out_path, mask):
                raise OSError(f"cv2.imwrite could not write {out_path}")

In [None]:
# Stage-2 input pipeline: read the images saved by stage 1 and convert each
# one to a tensor with torchvision's ToTensor.
transform = transforms.Compose([
    transforms.ToTensor(),
])
dataset = SegmentationDataset(img_dir='D:/study/dl/hackathon/github directory/first_stage_outputs',transform=transform)
# NOTE(review): no resizing is applied, so a batch size > 1 presumably requires
# all images to share one resolution - confirm before raising this.
batch_S = 1
# Inference only: each result is saved under its own filename, so shuffling
# adds nothing and makes the processing order non-reproducible.
dataloader = DataLoader(dataset, batch_size=batch_S, shuffle=False)

In [None]:
# Stage-2 model setup: restore the second-stage UNet11 from its checkpoint.
dataloader_test =  dataloader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# pretrained=False: load_state_dict (strict by default) replaces every weight
# with the checkpoint's values, so fetching the ImageNet encoder weights first
# would only cost a download and then be discarded.
model_stage2 = ternausnet.models.UNet11(pretrained=False)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code - only load checkpoints from a trusted source.
checkpoint = torch.load("D:/study/dl/hackathon/making dataset/models/model_stage2_dice_loss_5.pth", map_location=device)
model_stage2.load_state_dict(checkpoint['model_state_dict'])
model_stage2 = model_stage2.to(device)

# Not used by the inference cell below; kept so that any other cell relying on
# these names keeps working.
criterion = DiceLoss()
# Bind the optimizer to the stage-2 network; the previous code passed the
# stage-1 `model`, which also made this cell fail on a fresh kernel whenever
# the stage-1 cells had not been run.
optimizer = Adam(model_stage2.parameters(), lr=0.001)

In [None]:
# Stage-2 inference: run the second model over every stage-1 output and save
# the sigmoid probability map as an 8-bit image under the same filename.
stage2_output_dir = "D:/study/dl/hackathon/github directory/second_stage_outputs"
# cv2.imwrite does not create missing folders; it just returns False.
os.makedirs(stage2_output_dir, exist_ok=True)

model_stage2.eval()  # Set the model to evaluation mode
with torch.no_grad():  # Don't compute gradients
    for images, img_names in dataloader:
        images = images.to(device)
        # Move the whole batch to the CPU once instead of once per item.
        outputs = torch.sigmoid(model_stage2(images)).cpu()
        # Pair each prediction with its filename; unlike range(batch_S) this
        # stays correct when the last batch is smaller than the batch size.
        for prob_map, img_name in zip(outputs, img_names):
            # Sigmoid output lies in [0, 1]; scale to 0..255 for an 8-bit image.
            mask = (prob_map.squeeze().numpy() * 255).astype('uint8')
            out_path = f"{stage2_output_dir}/{img_name}"
            if not cv2.imwrite(out_path, mask):
                raise OSError(f"cv2.imwrite could not write {out_path}")