In [62]:
import torch.nn as nn
import torch
import torchvision
from torchvision import models, transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import os

device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
print(f"Using device: {device}")


Using device: mps


In [63]:
class Compose(object):
	def __init__(self, transform_data, transform_target):
		self.transform_data = transform_data
		self.transform_target = transform_target

	def __call__(self, image, target):
		for t in self.transform_data:
			image = t(image)

		for t in self.transform_target:
			target = t(target)
			
		return image, target
	
class NonZeroToOne(object):
	def __call__(self, tensor):
		return torch.where(tensor !=0, torch.tensor(1), tensor)
	
transform_target = []
transform_target.append(transforms.Resize((224,224)))
transform_target.append(transforms.ToTensor())
transform_target.append(NonZeroToOne())
transform_data = []
transform_data.append(transforms.Resize((224,224)))
transform_data.append(transforms.ToTensor())
transform = Compose(transform_data, transform_target)


VOC2012_train = torchvision.datasets.VOCSegmentation('./data', image_set="train",transforms=transform)
VOC2012_test = torchvision.datasets.VOCSegmentation('./data', image_set="val", transforms=transform)

batch_size = 16
trainloader = DataLoader(VOC2012_train, batch_size=batch_size, shuffle=True)
testloader = DataLoader(VOC2012_train, batch_size=batch_size, shuffle=True)


In [64]:

#for i in range(200,300):
#    for j in range(200,300):
#        print((VOC2012_train[100][1][0][i][j].type(torch.int32)).item(),end="")
#    print("")
print(len(VOC2012_train))

print(VOC2012_train[0][1].shape)
print(VOC2012_train[0][0].shape)
print(VOC2012_test[0][1].shape)
print(VOC2012_test[0][0].shape)

1464
torch.Size([1, 224, 224])
torch.Size([3, 224, 224])
torch.Size([1, 224, 224])
torch.Size([3, 224, 224])


**Reference**

https://lee-jaewon.github.io/deep_learning_study/FCN/

https://medium.com/@msmapark2/fcn-%EB%85%BC%EB%AC%B8-%EB%A6%AC%EB%B7%B0-fully-convolutional-networks-for-semantic-segmentation-81f016d76204

https://gaussian37.github.io/vision-segmentation-fcn/

In [65]:
# input : (3, 224, 224)
# output_layer1 : (64, 112, 112)
# output_layer2 : (128, 56, 56)
# output_layer3 : (256, 28, 28)
# output_layer4 : (512, 14, 14)
# output_layer5 : (1024, 7, 7)

# output_fc6 : (4096, 7, 7)
# output_fc7 : (1000, 7, 7)
# output_fc8 : (10, 7, 7)

vgg16 = torchvision.models.vgg16(weights=True)

class FCN(nn.Module):
    def __init__(self, num_classes):
        super(FCN, self).__init__()
        self.conv1_1=nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3,padding=1)
        self.relu1_1=nn.ReLU()
        self.conv1_2=nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,padding=1)
        self.relu1_2=nn.ReLU()
        self.pool1=nn.MaxPool2d(2)

        self.conv2_1=nn.Conv2d(in_channels=64,out_channels=128,kernel_size=3,padding=1)
        self.relu2_1=nn.ReLU()
        self.conv2_2=nn.Conv2d(in_channels=128,out_channels=128,kernel_size=3,padding=1)
        self.relu2_2=nn.ReLU()
        self.pool2=nn.MaxPool2d(2)

        self.conv3_1=nn.Conv2d(in_channels=128,out_channels=256,kernel_size=3,padding=1)
        self.relu3_1=nn.ReLU()
        self.conv3_2=nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,padding=1)
        self.relu3_2=nn.ReLU()
        self.conv3_3=nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,padding=1)
        self.relu3_3=nn.ReLU()
        self.pool3=nn.MaxPool2d(2)
        
        self.conv4_1=nn.Conv2d(in_channels=256,out_channels=512,kernel_size=3,padding=1)
        self.relu4_1=nn.ReLU()
        self.conv4_2=nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1)
        self.relu4_2=nn.ReLU()
        self.conv4_3=nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1)
        self.relu4_3=nn.ReLU()
        self.pool4=nn.MaxPool2d(2)

        self.conv5_1=nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1)
        self.relu5_1=nn.ReLU()
        self.conv5_2=nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1)
        self.relu5_2=nn.ReLU()
        self.conv5_3=nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1)
        self.relu5_3=nn.ReLU()
        self.pool5=nn.MaxPool2d(2)
        
        self.score_pool3=nn.Conv2d(in_channels=256,out_channels=num_classes,kernel_size=1,stride=1,padding=0)
    
        self.score_pool4=nn.Conv2d(in_channels=512,out_channels=num_classes,kernel_size=1,stride=1,padding=0)
        
        self.fc6=nn.Sequential(
            nn.Conv2d(in_channels=512,out_channels=4096,kernel_size=1,stride=1,padding=0),
            nn.ReLU(),
            nn.Dropout2d()
        )
        self.fc7=nn.Sequential(
            nn.Conv2d(in_channels=4096,out_channels=4096,kernel_size=1,stride=1,padding=0),
            nn.ReLU(),
            nn.Dropout2d()
        )
        self.score_fr=nn.Conv2d(in_channels=4096,out_channels=num_classes,kernel_size=1,stride=1,padding=0)
        
        
        self.x8_up=nn.ConvTranspose2d(in_channels=num_classes,out_channels=num_classes,kernel_size=2*8,stride=8)
        self.x2_up=nn.ConvTranspose2d(in_channels=num_classes,out_channels=num_classes,kernel_size=2*2,stride=2)

    def forward(self,x):
        x=self.relu1_1(self.conv1_1(x))
        x=self.relu1_2(self.conv1_2(x))
        x=self.pool1(x)

        x=self.relu2_1(self.conv2_1(x))
        x=self.relu2_2(self.conv2_2(x))
        x=self.pool2(x)

        x=self.relu3_1(self.conv3_1(x))
        x=self.relu3_2(self.conv3_2(x))
        x=self.relu3_3(self.conv3_3(x))
        x=self.pool3(x)
        pool3_pred=self.score_pool3(x)

        x=self.relu4_1(self.conv4_1(x))
        x=self.relu4_2(self.conv4_2(x))
        x=self.relu4_3(self.conv4_3(x))
        x=self.pool4(x)
        pool4_pred=self.score_pool4(x)

        x=self.relu5_1(self.conv5_1(x))
        x=self.relu5_2(self.conv5_2(x))
        x=self.relu5_3(self.conv5_3(x))
        x=self.pool5(x)

        x=self.fc6(x)
        x=self.fc7(x)
        x=self.score_fr(x)
        us=self.x2_up(x)
        us=us[:,:,1:-1,1:-1]
    
        us_with_pool4_prediction=us+pool4_pred

        us2=self.x2_up(us_with_pool4_prediction)
        us2=us2[:,:,1:-1,1:-1]
        us2_with_pool3_prediction=us2+pool3_pred
        FCN_8s=self.x8_up(us2_with_pool3_prediction)
        FCN_8s=FCN_8s[:,:,4:-4,4:-4]
        return FCN_8s
    
    def copy_params_from_vgg16(self, vgg16):
        features=[
            self.conv1_1, self.relu1_1,
            self.conv1_2, self.relu1_2,
            self.pool1,
            self.conv2_1, self.relu2_1,
			self.conv2_2, self.relu2_2,
            self.pool2,
            self.conv3_1, self.relu3_1,
			self.conv3_2, self.relu3_2,
			self.conv3_3, self.relu3_3,
            self.pool3,
            self.conv4_1, self.relu4_1,
			self.conv4_2, self.relu4_2,
			self.conv4_3, self.relu4_3,
            self.pool4,
            self.conv5_1, self.relu5_1,
			self.conv5_2, self.relu5_2,
			self.conv5_3, self.relu5_3,
            self.pool5
        ]
        for l1, l2 in zip(vgg16.features, features):
            if isinstance(l1, nn.Conv2d) and isinstance(l2,nn.Conv2d):
                assert l1.weight.size()==l2.weight.size()
                assert l1.bias.size()==l2.bias.size()
                l2.weight.data.copy_(l1.weight.data)
                l2.bias.data.copy_(l1.bias.data)
        
        
    

In [66]:
from tqdm.auto import tqdm

vgg16 = torchvision.models.vgg16(weights=True)

#learning_rate=1E-3

model=FCN(num_classes=21).to(device)
model.copy_params_from_vgg16(vgg16)
#optim=torch.optim.Adam(params=model.parameters(),lr=learning_rate)


epoch=1
loss_fn=nn.CrossEntropyLoss()

#print(len(trainloader.dataset))
lr = 10**-1
momentum = 0.9
weight_decay = 2**-4
optim = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

#

for epoch_cnt in range(epoch):
    last_loss=0
    for index, (data, target) in tqdm(enumerate(trainloader)):
        #print(data.shape)
        data, target = data.to(device), target.to(device)
        optim.zero_grad()
        prediction=model(data)
        #print(prediction.max())
        target = torch.squeeze(target)
        target = target.long()
        loss=loss_fn(prediction,target)
        loss.backward()
        last_loss=last_loss+loss.item()
        optim.step()
    print(last_loss/len(trainloader))
    

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

with torch.no_grad():
    for index, (input,output) in enumerate(trainloader):
        input,output=input.to(device), output.to(device)
        prediction=model(input)

        prediction=prediction[0].cpu()
        print(prediction)
        print(prediction.max())
        prediction = torch.argmax(prediction, dim=0)
        print(prediction.shape)
        print(prediction.max())

        cmap = plt.cm.get_cmap('tab20', 20)
        plt.subplot(121)
        for i in range(20):
            plt.imshow(prediction, cmap=cmap, alpha=0.5)
        plt.subplot(122)
        for i in range(20):
            plt.imshow(output.cpu()[0].squeeze(), cmap=cmap, alpha=0.5)

        break

tensor([[[ 2.5599e-01,  4.0321e-01,  2.9290e-01,  ...,  2.9288e-01,
           2.6524e-01,  3.2477e-01],
         [ 3.4972e-01,  3.3274e-01,  2.8565e-01,  ...,  3.4133e-01,
           3.5300e-01,  3.3788e-01],
         [ 3.0424e-01,  3.2251e-01,  2.7778e-01,  ...,  2.1243e-01,
           4.0355e-01,  3.5414e-01],
         ...,
         [ 3.7433e-01,  3.7313e-01,  2.8504e-01,  ...,  3.1121e-01,
           4.0119e-01,  2.6791e-01],
         [ 2.7598e-01,  3.4845e-01,  3.9136e-01,  ...,  3.4372e-01,
           3.2231e-01,  3.0929e-01],
         [ 4.4944e-01,  3.3752e-01,  3.0482e-01,  ...,  3.5178e-01,
           2.6787e-01,  2.3787e-01]],

        [[ 1.8478e-03, -1.0671e-01,  3.9923e-02,  ...,  3.0852e-02,
           3.2677e-03,  3.8513e-02],
         [ 5.1723e-03, -4.2358e-02, -2.9416e-02,  ..., -1.2927e-02,
          -5.1456e-02, -1.6787e-02],
         [ 3.6208e-02,  1.2849e-01,  4.3911e-02,  ...,  2.1054e-02,
           4.9695e-02, -3.3475e-02],
         ...,
         [ 8.7353e-03,  1

AttributeError: module 'matplotlib.cm' has no attribute 'get_cmap'