In [None]:
import os
import cv2
from PIL import Image
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
#from torch.optim.lr_scheduler import _LRScheduler

from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2

from torchvision import models
from torchsummary import summary

# Expose only the GPU with index 2 to CUDA, then select CUDA when it is
# available and fall back to the CPU otherwise.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Echo the GPU mask that is currently in effect.
print(os.environ.get('CUDA_VISIBLE_DEVICES'))

In [None]:
class GradReverse(torch.autograd.Function):
    """Gradient reversal layer.

    The forward pass is the identity; the backward pass flips the sign of the
    incoming gradient.
    """

    @staticmethod
    def forward(ctx, x):
        """Return a view of ``x`` with the same shape (identity mapping)."""
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the negated upstream gradient."""
        reversed_grad = -1 * grad_output
        return reversed_grad

class domain_classifier(nn.Module):
    """Two-layer domain discriminator fed through a gradient-reversal layer.

    The input is flattened to ``(-1, in_features)``, passed through
    ``GradReverse`` and two linear layers, and squashed with a sigmoid
    (source = 0, target = 1, treated as a regression target).

    Args:
        in_features: number of values per sample after flattening. Defaults
            to ``224 * 224 * 64``, the size that was previously hard-coded.
        hidden_features: width of the hidden layer (default 10).
    """

    def __init__(self, in_features=224 * 224 * 64, hidden_features=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, 1)  # source = 0, target = 1

    def forward(self, x):
        """Return the sigmoid domain score for each flattened sample."""
        # reshape (rather than view) also accepts non-contiguous feature maps;
        # the flattened width is read from fc1 so it is defined in one place.
        x = x.reshape(-1, self.fc1.in_features)
        x = GradReverse.apply(x)  # reverse gradients flowing back to the features
        x = F.leaky_relu(self.fc1(x))
        x = self.fc2(x)

        return torch.sigmoid(x)

# Basic convolution block used by the U-Net.
class ConvBlock(nn.Module):
    """Two Conv2d -> BatchNorm2d -> ReLU stages.

    The first convolution keeps ``in_channels``; the second maps
    ``in_channels`` to ``out_channels``. Padding is derived from
    ``kernel_size`` (``kernel_size // 2``), so odd kernel sizes preserve the
    spatial size. For the default ``kernel_size=3`` this equals the previous
    fixed padding of 1.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        kernel_size: square kernel size for both convolutions (default 3).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.kernel_size = kernel_size
        # Padding follows the kernel size instead of being pinned to 1.
        padding = kernel_size // 2

        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, padding=padding)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu1 = nn.ReLU()

        # The channel count changes here: in_channels -> out_channels.
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        """Apply both conv/BN/ReLU stages and return the result."""
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        return x
class IdentityBlock(nn.Module):
    """Two bias-free 3x3 Conv2d -> BatchNorm2d -> ReLU stages.

    Both convolutions use ``padding=1`` so the block keeps the spatial size
    when ``stride == 1``. Without padding every block shrank the feature map
    by 4 pixels per side-length, which makes the deeper encoder stages run
    out of pixels for moderately sized inputs.

    Note: despite the name, this block has no residual/skip connection.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the second convolution.
        stride: stride of the first convolution (default 1).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # 3x3 convolution, channel count unchanged; carries the stride.
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu1 = nn.ReLU()
        # 3x3 convolution, in_channels -> out_channels.
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        """Apply both conv/BN/ReLU stages and return the result."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        return out


# Encoder blocks
class Conv2(nn.Module):
    """Encoder stage: two IdentityBlocks followed by a 3x3, stride-2 max-pool.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the second IdentityBlock.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The first block keeps the channel count; the second one changes it.
        self.identityblock1 = IdentityBlock(in_channels, in_channels)
        self.identityblock2 = IdentityBlock(in_channels, out_channels)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        """Return ``(features, pooled)``: the pre-pool features and their pooled version."""
        features = self.identityblock2(self.identityblock1(x))
        pooled = self.maxpool(features)
        # NOTE(review): this shape trace is printed on every forward call.
        print("Conv2 x: ", features.shape, "maxpulling: ", pooled.shape)
        return features, pooled
class Conv3(nn.Module):
    """Encoder stage: two IdentityBlocks followed by a 3x3, stride-2 max-pool.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the second IdentityBlock.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Channel count is kept by block 1 and changed by block 2.
        self.identityblock1 = IdentityBlock(in_channels, in_channels)
        self.identityblock2 = IdentityBlock(in_channels, out_channels)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        """Return ``(features, pooled)``: the pre-pool features and their pooled version."""
        features = x
        for block in (self.identityblock1, self.identityblock2):
            features = block(features)
        pooled = self.maxpool(features)
        # NOTE(review): this shape trace is printed on every forward call.
        print("Conv3 x: ", features.shape, "maxpulling: ", pooled.shape)
        return features, pooled
class Conv4(nn.Module):
    """Encoder stage: two IdentityBlocks followed by a 3x3, stride-2 max-pool.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the second IdentityBlock.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Block 1 preserves the channel count, block 2 maps to out_channels.
        self.identityblock1 = IdentityBlock(in_channels, in_channels)
        self.identityblock2 = IdentityBlock(in_channels, out_channels)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        """Return ``(features, pooled)``: the pre-pool features and their pooled version."""
        hidden = self.identityblock1(x)
        features = self.identityblock2(hidden)
        pooled = self.maxpool(features)
        # NOTE(review): this shape trace is printed on every forward call.
        print("Conv4 x: ", features.shape, "maxpulling: ", pooled.shape)
        return features, pooled
class Conv5(nn.Module):
    """Encoder stage: two IdentityBlocks followed by a 3x3, stride-2 max-pool.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the second IdentityBlock.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Same layout as the other encoder stages: keep channels, then change them.
        self.identityblock1 = IdentityBlock(in_channels, in_channels)
        self.identityblock2 = IdentityBlock(in_channels, out_channels)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        """Return ``(features, pooled)``: the pre-pool features and their pooled version."""
        features = self.identityblock2(self.identityblock1(x))
        pooled = self.maxpool(features)
        # NOTE(review): this shape trace is printed on every forward call.
        print("Conv5 x: ", features.shape, "maxpulling: ", pooled.shape)
        return features, pooled
# Decoder block
class DecoderBlock(nn.Module):
    """Upsample, merge with a skip connection, then refine with a ConvBlock.

    The transposed convolution maps ``channels * 2`` input channels to
    ``channels``. After concatenation with the skip tensor the ConvBlock is
    fed ``channels * 2`` channels again and outputs ``channels``.

    Args:
        channels: number of output channels of the block.
    """

    def __init__(self, channels):
        super().__init__()
        doubled = channels * 2
        # output_padding=1 is part of the transposed-convolution configuration.
        self.upsample = nn.ConvTranspose2d(
            doubled, channels, kernel_size=4, stride=2, padding=1, output_padding=1
        )
        self.convblock1 = ConvBlock(doubled, channels)

    def forward(self, x, skip):
        """Upsample ``x``, match it to ``skip`` spatially, concatenate and convolve."""
        upsampled = self.upsample(x)
        target_hw = (skip.size(2), skip.size(3))
        # Resize only when the upsampled map does not already match the skip tensor.
        if (upsampled.size(2), upsampled.size(3)) != target_hw:
            upsampled = F.interpolate(upsampled, size=target_hw)
        merged = torch.cat([upsampled, skip], dim=1)
        return self.convblock1(merged)

# U-Net structure; pay attention to how the bottleneck tensor `xm` flows.
class Unet(nn.Module):
    """U-Net style segmentation network with an auxiliary domain classifier.

    Encoder channel widths are 64 -> 128 -> 256 -> 512 -> 1024 with a
    2048-channel bottleneck. Each ``DecoderBlock(c)`` consumes ``2 * c``
    channels plus a ``c``-channel skip tensor, so the decoder chain
    1024/512/256/128/64 lines up with the skips ``x5 .. x1``.

    The previous widths (``Conv5(512, 512)`` and a 1024-channel bottleneck)
    did not match ``decoder5``, which expects 2048 input channels and a
    1024-channel skip.

    Args:
        n_classes: number of output channels of the segmentation head.
    """

    def __init__(self, n_classes):
        super().__init__()
        # Stem: 3 -> 64 channels, followed by a max-pool.
        self.fconv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2)
        self.fbn1 = nn.BatchNorm2d(64)
        self.frelu1 = nn.ReLU()
        self.fmaxpooling = nn.MaxPool2d(kernel_size=3, stride=2)

        # Encoder stages; each returns (skip features, pooled features).
        self.conv2 = Conv2(64, 128)
        self.conv3 = Conv3(128, 256)
        self.conv4 = Conv4(256, 512)
        self.conv5 = Conv5(512, 1024)  # 1024-channel skip required by decoder5

        # Bottleneck: 1024 -> 2048 so decoder5 receives 2 * 1024 channels.
        self.middleconv = ConvBlock(1024, 2048)
        self.dropout = nn.Dropout2d(0.4)

        self.decoder5 = DecoderBlock(1024)
        self.decoder4 = DecoderBlock(512)
        self.decoder3 = DecoderBlock(256)
        self.decoder2 = DecoderBlock(128)
        self.decoder1 = DecoderBlock(64)

        self.segmap = nn.Conv2d(64, n_classes, kernel_size=1)
        self.domain_classifier = domain_classifier()

    def forward(self, x):
        """Return ``(x_c, x_d)``.

        ``x_c`` is the output of the 1x1 segmentation head (no activation is
        applied here) and ``x_d`` is the output of ``self.domain_classifier``;
        both are computed from the 64-channel decoder output resized to 224x224.
        """
        x = self.fconv1(x)  # 3 -> 64 channels
        x = self.fbn1(x)
        x1 = self.frelu1(x)  # skip for decoder1 (64 channels)
        p = self.fmaxpooling(x1)

        x2, p = self.conv2(p)  # skip: 128 channels
        x3, p = self.conv3(p)  # skip: 256 channels
        x4, p = self.conv4(p)  # skip: 512 channels
        x5, p = self.conv5(p)  # skip: 1024 channels

        xm = self.middleconv(p)  # bottleneck: 2048 channels
        xm = self.dropout(xm)

        x = self.decoder5(xm, x5)  # 2048 -> 1024
        x = self.decoder4(x, x4)   # 1024 -> 512
        x = self.decoder3(x, x3)   # 512 -> 256
        x = self.decoder2(x, x2)   # 256 -> 128
        x = self.decoder1(x, x1)   # 128 -> 64

        # Fixed 224x224 output resolution, shared by both heads.
        x = F.interpolate(x, size=(224, 224))
        x_c = self.segmap(x)
        x_d = self.domain_classifier(x)
        return x_c, x_d