In [50]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as data


from torchvision import datasets, transforms

In [2]:
class conv(nn.Module):
    def __init__(self, input_channel, output_channel):
        super(conv, self).__init__()
        
        self.conv1 = nn.Conv2d(input_channel, output_channel, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(output_channel)
        
        self.conv2 = nn.Conv2d(output_channel, output_channel, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(output_channel)
        
        self.relu = nn.ReLU()
        
    def forward(self, Z):
        Z = self.conv1(Z)
        Z = self.bn1(Z)
        Z = self.relu(Z)
        
        Z = self.conv2(Z)
        Z = self.bn2(Z)
        Z = self.relu(Z)
        
        return Z
    

class encoder(nn.Module):
    def __init__(self, input_channel, output_channel):
        super(encoder, self).__init__()
        
        self.conv = conv(input_channel, output_channel)
        self.pool = nn.MaxPool2d((2,2))
        
        
    def forward(self, Z):
        Z = self.conv(Z)
        P = self.pool(Z)
        
        return Z, P
    
class decoder(nn.Module):
    def __init__(self, input_channel, output_channel):
        super(decoder, self).__init__()
        
        self.up = nn.ConvTranspose2d(input_channel, output_channel, kernel_size=2, stride=2, padding=0)
        self.conv = conv(output_channel+output_channel, output_channel)
        
    def forward(self, Z, skip):
        Z = self.up(Z)
        Z = torch.cat([Z, skip], axis=1)
        Z = self.conv(Z)
        
        return Z

class build_unet(nn.Module):
    def __init__(self):
        super(build_unet, self).__init__()
        
        """Encoder"""
        self.e1 = encoder(3, 64)
        self.e2 = encoder(64, 128)
        self.e3 = encoder(128, 256)
        #self.e4 = encoder(256, 512)
        
        """Bottleneck"""
        #self.b = conv(512, 1024)
        self.b = conv(256, 512)
        
        
        """Decoder"""
        self.d1 = decoder(512, 256)
        self.d2 = decoder(256, 128)
        self.d3 = decoder(128, 64)
        #self.d4 = decoder(128, 64)
        
        """Output"""
        self.output = nn.Conv2d(64, 2, kernel_size=1, padding=0)
        
    def forward(self, Z):
        Z1, P1 = self.e1(Z)
        Z2, P2 = self.e2(P1)
        Z3, P3 = self.e3(P2)
        #Z4, P4 = self.e4(Z3)
            
        b = self.b(P3)

        Z5 = self.d1(b, Z3)
        Z6 = self.d2(Z5, Z2)
        Z7 = self.d3(Z6, Z1)
        #Z8 = self.d4(Z7, P1)
        
        output = self.output(Z7)
        
        return output
        

In [3]:
def set_device():
    if torch.cuda.is_available():
        dev ="cuda:0"
    else:
        dev = "cpu"
    return torch.device(dev)

In [14]:
train_data_images = torchvision.datasets.ImageFolder(root = 'training/images', loader = torchvision.io.read_image)
train_data_groundtruth = torchvision.datasets.ImageFolder(root = 'training/groundtruth', loader = torchvision.io.read_image)
#train_loader_images = torch.utils.data.DataLoader(train_data_images)
#print(train_loader_images)

<torch.utils.data.dataloader.DataLoader object at 0x0000020387E67AF0>


In [47]:
train_images = []
for i in range (len(train_data_images)):
    train_images.append(torch.stack((train_data_images[i][0][0], train_data_images[i][0][1], train_data_images[i][0][2]), -1))

transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
train_set = datasets.MNIST("training/images", train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=4,
        shuffle=True,  # Can be important for training
        pin_memory=torch.cuda.is_available(),
        drop_last=True,
        num_workers=2,
    )
train_loader = train_loader.ToTensor()

AttributeError: 'DataLoader' object has no attribute 'ToTensor'

In [46]:
#inputs = torch.randn((2, 3, 512, 512))
inputs = train_loader
model = build_unet()
#inputs = inputs.to(set_device())
x = model(inputs)

TypeError: conv2d() received an invalid combination of arguments - got (DataLoader, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: ([31;1mDataLoader[0m, [31;1mParameter[0m, [31;1mParameter[0m, [31;1mtuple[0m, [31;1mtuple[0m, [31;1mtuple[0m, [32;1mint[0m)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: ([31;1mDataLoader[0m, [31;1mParameter[0m, [31;1mParameter[0m, [31;1mtuple[0m, [31;1mtuple[0m, [31;1mtuple[0m, [32;1mint[0m)


In [None]:
print(x.shape)

torch.Size([2, 2, 512, 512])
