In [7]:
import torch.nn as nn
import torch
import torchvision
from torchvision import models, transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import os

In [8]:
class Compose(object):
	def __init__(self, transform_data, transform_target):
		self.transform_data = transform_data
		self.transform_target = transform_target

	def __call__(self, image, target):
		for t in self.transform_data:
			image = t(image)

		for t in self.transform_target:
			target = t(target)
			
		return image, target
	
class NonZeroToOne(object):
	def __call__(self, tensor):
		return torch.where(tensor !=0, torch.tensor(1), tensor)
	
transform_target = []
transform_target.append(transforms.Resize((512,512)))
transform_target.append(transforms.ToTensor())
#transform_target.append(NonZeroToOne())
transform_data = []
transform_data.append(transforms.Resize((512,512)))
transform_data.append(transforms.ToTensor())
transform = Compose(transform_data, transform_target)


VOC2012_train = torchvision.datasets.VOCSegmentation('./data', image_set="train",transforms=transform)
VOC2012_test = torchvision.datasets.VOCSegmentation('./data', image_set="val", transforms=transform)


batch_size = 64
trainloader = DataLoader(VOC2012_train, batch_size=batch_size, shuffle=True)
testloader = DataLoader(VOC2012_train, batch_size=batch_size, shuffle=True)

**Reference**

https://lee-jaewon.github.io/deep_learning_study/FCN/

https://medium.com/@msmapark2/fcn-%EB%85%BC%EB%AC%B8-%EB%A6%AC%EB%B7%B0-fully-convolutional-networks-for-semantic-segmentation-81f016d76204

https://gaussian37.github.io/vision-segmentation-fcn/

In [9]:
# input : (3, 512, 512)
# output_layer1 : (64, 256, 256)
# output_layer2 : (128, 128, 128)
# output_layer3 : (256, 64, 64)
# output_layer4 : (512, 32, 32)
# output_layer5 : (1024, 16, 16)

# output_fc6 : (4096, 16, 16)
# output_fc7 : (1000, 16, 16)
# output_fc8 : (10, 16, 16)

class FCN(nn.Module):
    def __init__(self, num_classes):
        super(FCN, self).__init__()
        self.layer1=nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.layer2=nn.Sequential(
            nn.Conv2d(in_channels=64,out_channels=128,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=128,out_channels=128,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.layer3=nn.Sequential(
            nn.Conv2d(in_channels=128,out_channels=256,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.pool3_pred=nn.Sequential(
            nn.Conv2d(in_channels=256,out_channels=num_classes,kernel_size=1,stride=1,padding=0)
        )
        self.layer4=nn.Sequential(
            nn.Conv2d(in_channels=256,out_channels=512,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.pool4_pred=nn.Sequential(
            nn.Conv2d(in_channels=512,out_channels=num_classes,kernel_size=1,stride=1,padding=0)
        )
        self.layer5=nn.Sequential(
            nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.fc6=nn.Sequential(
            nn.Conv2d(in_channels=512,out_channels=4096,kernel_size=1,stride=1,padding=0),
            nn.ReLU()
        )
        self.fc7=nn.Sequential(
            nn.Conv2d(in_channels=4096,out_channels=4096,kernel_size=1,stride=1,padding=0),
            nn.ReLU()
        )
        self.fc8=nn.Conv2d(in_channels=4096,out_channels=num_classes,kernel_size=1,stride=1,padding=0)
        
        self.x32_up=nn.ConvTranspose2d(in_channels=num_classes,out_channels=num_classes,kernel_size=2*32,stride=32,padding=16)
        self.x16_up=nn.ConvTranspose2d(in_channels=num_classes,out_channels=num_classes,kernel_size=2*16,stride=16,padding=8)
        self.x8_up=nn.ConvTranspose2d(in_channels=num_classes,out_channels=num_classes,kernel_size=2*8,stride=8,padding=4)
        self.x2_up=nn.ConvTranspose2d(in_channels=num_classes,out_channels=num_classes,kernel_size=2*2,stride=2,padding=1)

    def forward(self,x,type):
        x=self.layer1(x)
        x=self.layer2(x)
        x=self.layer3(x)
        pool3_prediction=self.pool3_pred(x)

        x=self.layer4(x)
        pool4_prediction=self.pool4_pred(x)

        x=self.layer5(x)
        x=self.fc6(x)
        x=self.fc7(x)
        x=self.fc8(x)
        x2_upsample=self.x2_up(x)
        FCN_32s=self.x32_up(x)

        x2_upsample_with_pool4=x2_upsample+pool4_prediction
        FCN_16s=self.x16_up(x2_upsample_with_pool4)

        x4_upsample_with_pool4=self.x2_up(x2_upsample_with_pool4)
        FCN_8s=self.x8_up(x4_upsample_with_pool4+pool3_prediction)

        if type==8:
            return FCN_8s
        if type==16:
            return FCN_16s
        if type==32:
            return FCN_32s
        return 0


In [15]:


from tqdm import tqdm


learning_rate=1E-3
batch_size=100

model=FCN(num_classes=2)
optim=torch.optim.Adam(params=model.parameters(),lr=learning_rate)


type=8
epoch=3
loss_fn=nn.CrossEntropyLoss()

print(trainloader)

for epoch_cnt in range(epoch):
    last_loss=0
    for index, (data, target) in tqdm(enumerate(trainloader)):
        optim.zero_grad()
        prediction=model(data,type=type)
        loss=loss_fn(prediction,target)
        
        loss.backward()
        optim.step()
    
    with torch.no_grad():
        for index, (input,output) in enumerate(testloader):
            prediction=model(input)
            
            compare=prediction==output
            correct_area=torch.sum(compare).item()
            total_area=(batch_size*512*512)
            print(f"IoU : {correct_area / (batch_size * 512 * 512)}")

    
        

        


<torch.utils.data.dataloader.DataLoader object at 0x1022a32c0>


0it [00:28, ?it/s]


KeyboardInterrupt: 