#GPUの種類を表示

In [None]:
# List the GPU(s) visible to this runtime (IPython shell escape, not plain Python).
!nvidia-smi

#Driveのマウント

In [None]:
# Mount Google Drive at /content/drive; datasets and checkpoints used later live
# under its "My Drive" folder.
import google.colab.drive as drive
drive.mount('/content/drive')

#ライブラリインポート

In [None]:
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import time
import subprocess as sp
from datetime import datetime, timedelta, timezone
import math
from torchvision import datasets, models, transforms
# ImageNet-pretrained VGG-19; its ``features`` layers are reused by the
# perceptual loss (Vgg19Loss) defined further down.
# NOTE(review): torchvision >= 0.13 deprecates ``pretrained=True`` in favour of
# the ``weights=`` argument; left unchanged for compatibility with older runtimes.
vgg19 = models.vgg19(pretrained=True)

#GPUランタイムの使用時間を表示（VM起動からの経過時間 /proc/uptime、単位: 時間）

In [None]:
# Hours since this runtime's VM booted, used as a proxy for how long the GPU
# runtime has been in use. The first field of /proc/uptime is the uptime in
# seconds. Reading the file directly avoids spawning `cat` and `awk`; the old
# version also never waited on or closed the `cat` subprocess.
with open("/proc/uptime") as uptime_file:
    use_time = float(uptime_file.read().split()[0]) / 60 / 60  # unit: hours
print(use_time)

#パラメータの定義

In [None]:
# Work inside the mounted Drive folder; the dataset roots below are relative to it.
# NOTE(review): %cd is relative to the current directory, so re-running this
# cell from anywhere other than /content fails.
%cd drive/My\ Drive
dataroot1="new_persons"  # ImageFolder root of person images
dataroot2="new_cloths"  # ImageFolder root of clothing images
dataroot3="new_segmentations"  # ImageFolder root of colour-coded segmentation maps
num_thread=0  # DataLoader num_workers (0 = load in the main process)
batch_size=16
num_epoch=15
img_size=(128,96)  # (height, width) given to transforms.Resize
lr=0.0002  # Adam learning rate for both networks
b1=0.5  # Adam beta1
b2=0.999  # Adam beta2
ngpu=1  # stored by the model classes; not otherwise used

#データセットのロード

In [None]:
# Every image is resized to img_size (height, width) and converted by ToTensor
# to a float tensor in [0, 1].
# NOTE(review): Resize interpolates bilinearly by default, which blends label
# colours along region borders of the segmentation maps, while
# onechanel_trans() matches colours exactly. Border pixels may therefore be
# lost unless the maps are already img_size. Consider nearest-neighbour
# resizing for dataroot3 (TODO confirm).
p_dataset, c_dataset, seg_dataset = (
    dset.ImageFolder(
        root=root,
        transform=transforms.Compose([transforms.Resize(img_size), transforms.ToTensor()]),
    )
    for root in (dataroot1, dataroot2, dataroot3)
)

# The training loop consumes the cloth and segmentation loaders in lock-step
# with the person loader. If either holds fewer images it runs out mid-epoch,
# so fail here with a clear message. Pairing is by position, so presumably the
# folders also need matching file names in sorted order (TODO confirm).
if len(c_dataset) < len(p_dataset) or len(seg_dataset) < len(p_dataset):
    raise ValueError(
        "dataset sizes do not line up: persons={}, cloths={}, segmentations={}".format(
            len(p_dataset), len(c_dataset), len(seg_dataset)
        )
    )

p_dataloader = torch.utils.data.DataLoader(p_dataset, batch_size=batch_size, shuffle=False, num_workers=num_thread)
c_dataloader = torch.utils.data.DataLoader(c_dataset, batch_size=batch_size, shuffle=False, num_workers=num_thread)
# Same clothes in random order; the training loop uses these as mismatched
# (person, garment) pairs that the discriminator must label fake.
shuffle_c_dataloader = torch.utils.data.DataLoader(c_dataset, batch_size=batch_size, shuffle=True, num_workers=num_thread)
seg_dataloader = torch.utils.data.DataLoader(seg_dataset, batch_size=batch_size, shuffle=False, num_workers=num_thread)
device = torch.device("cuda:0")

#各種関数定義

##衣服マスクの1チャネル化

In [None]:
def onechanel_trans(x):
    """Extract the clothing region of a colour-coded segmentation batch as a mask.

    A pixel belongs to the clothing class when R == 1, B == 0 and G rounds to
    0.3333 at 4 decimals. For 8-bit images scaled by ``ToTensor`` that is the
    colour (255, 85, 0).

    Args:
        x: Float tensor of shape (B, C >= 3, H, W) with RGB values in [0, 1].

    Returns:
        Tensor of shape (1, B, 1, H, W) in the default float dtype, on ``x``'s
        device: 1 at clothing pixels, 0 elsewhere. The leading singleton axis
        is kept because callers index ``[0]`` to get a (B, 1, H, W) mask.
    """
    red, green, blue = x[:, 0], x[:, 1], x[:, 2]
    # Vectorised form of ``round(G, 4) == 0.3333``, evaluated in float64 to
    # match Python's double-precision round() applied to the float32 value.
    green_matches = torch.round(green.double() * 10000) == 3333
    mask = (red == 1) & green_matches & (blue == 0)  # (B, H, W)
    return mask.unsqueeze(1).unsqueeze(0).to(torch.get_default_dtype())

##マスク領域を人物画像上で白(1.0)に塗りつぶす（αブレンドではなく上書き）

In [None]:
def overlay(x, y):
    """Paint the masked pixels of an image batch white, in place.

    Args:
        x: Mask tensor of shape (B, >=1, H, W). Only channel 0 is read; any
           non-zero value marks a pixel to overwrite.
        y: Image tensor of shape (B, C, H, W). It is modified in place.

    Returns:
        ``y`` itself (the same tensor object), with every channel set to 1 at
        the masked pixels.
    """
    # (B, 1, H, W) boolean mask; it broadcasts across y's channel axis.
    mask = (x[:, :1] != 0).to(y.device)
    return y.masked_fill_(mask, 1)

##重みの初期化関数

In [None]:
def init_weights(model, mean=0.0, std=0.002):
    """Initialise one (sub)module; intended to be passed to ``nn.Module.apply``.

    ``apply`` calls this once per submodule, so ``model`` is a single layer
    here. For ``Conv2d`` / ``ConvTranspose2d``:
    - weights are drawn from N(mean, std**2);
    - biases, if present, are zeroed.
    Every other module is left untouched.

    Args:
        model: Module to initialise.
        mean: Mean of the normal distribution for conv weights.
        std: Its standard deviation. 0.002 is the value this notebook specified.
            NOTE(review): the DCGAN reference initialisation uses 0.02; confirm.
    """
    if isinstance(model, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(model.weight, mean, std)
        # The convs in this notebook are built with bias=False, so bias is None.
        if model.bias is not None:
            nn.init.zeros_(model.bias)

#各種クラス定義

##Encoder-Decoder

In [None]:
class EncoderDecoder(nn.Module):
    """Two-encoder U-Net that generates a person image wearing a given garment.

    * Encoder1 takes a 4-channel input. The training and test cells feed it the
      person image (clothing region painted out) concatenated with the
      1-channel clothing mask.
    * Encoder2 takes the 3-channel clothing image.
    * Each encoder halves height and width five times with 2x2 / stride-2
      convolutions (64 -> 1024 channels).
    * The decoder concatenates both bottlenecks and upsamples five times with
      2x2 / stride-2 transposed convolutions. At every scale it concatenates
      the matching activations of *both* encoders (U-Net skip connections).
      It ends with a 1x1 convolution + sigmoid giving 3 channels in (0, 1).

    Height and width must be divisible by 32 for the skip connections to line
    up (the notebook uses 128x96). The encoders cache their activations on
    ``self`` and ``Decoder()`` reads them, so both encoders must run before the
    decoder. ``forward`` runs the three steps in order.

    Attribute names are ``state_dict`` keys stored in checkpoints, and their
    registration order fixes the parameter order that optimizer checkpoints rely
    on. Do not rename or reorder them.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu  # stored only; not used by the layers
        # Encoder1: 4 -> 64 -> 128 -> 256 -> 512 -> 1024 channels.
        self.down1_1 = nn.Conv2d(4, 64, 2, 2, 0, bias=False)
        self.Leaky = nn.LeakyReLU(0.2, inplace=True)  # shared by both encoders
        self.down1_2 = nn.Conv2d(64, 128, 2, 2, 0, bias=False)
        self.batch1_1 = nn.BatchNorm2d(128)
        self.down1_3 = nn.Conv2d(128, 256, 2, 2, 0, bias=False)
        self.batch1_2 = nn.BatchNorm2d(256)
        self.down1_4 = nn.Conv2d(256, 512, 2, 2, 0, bias=False)
        self.batch1_3 = nn.BatchNorm2d(512)
        self.down1_5 = nn.Conv2d(512, 1024, 2, 2, 0, bias=False)
        self.batch1_4 = nn.BatchNorm2d(1024)

        # Encoder2: 3 -> 64 -> 128 -> 256 -> 512 -> 1024 channels.
        self.down2_1 = nn.Conv2d(3, 64, 2, 2, 0, bias=False)
        self.down2_2 = nn.Conv2d(64, 128, 2, 2, 0, bias=False)
        self.batch2_1 = nn.BatchNorm2d(128)
        self.down2_3 = nn.Conv2d(128, 256, 2, 2, 0, bias=False)
        self.batch2_2 = nn.BatchNorm2d(256)
        self.down2_4 = nn.Conv2d(256, 512, 2, 2, 0, bias=False)
        self.batch2_3 = nn.BatchNorm2d(512)
        self.down2_5 = nn.Conv2d(512, 1024, 2, 2, 0, bias=False)
        self.batch2_4 = nn.BatchNorm2d(1024)

        # Decoder: the input channel counts include both encoders' skip tensors.
        self.up3_1 = nn.ConvTranspose2d(2048, 1024, 2, 2, 0, bias=False)
        self.batch3_1 = nn.BatchNorm2d(1024)
        self.relu = nn.ReLU(True)
        self.up3_2 = nn.ConvTranspose2d(2048, 512, 2, 2, 0, bias=False)
        self.batch3_2 = nn.BatchNorm2d(512)
        self.up3_3 = nn.ConvTranspose2d(1024, 256, 2, 2, 0, bias=False)
        self.batch3_3 = nn.BatchNorm2d(256)
        self.up3_4 = nn.ConvTranspose2d(512, 128, 2, 2, 0, bias=False)
        self.batch3_4 = nn.BatchNorm2d(128)
        self.up3_5 = nn.ConvTranspose2d(256, 64, 2, 2, 0, bias=False)
        self.batch3_5 = nn.BatchNorm2d(64)
        self.down3_1 = nn.Conv2d(64, 3, 1, 1, 0, bias=False)  # 1x1 projection to RGB
        self.Sigmoid = nn.Sigmoid()

    def Encoder1(self, input1):
        """Encode the 4-channel person+mask input; cache skips in ``self.Leaky1_*``."""
        down1_1 = self.down1_1(input1)
        self.Leaky1_1 = self.Leaky(down1_1)
        down1_2 = self.down1_2(self.Leaky1_1)
        batch1_1 = self.batch1_1(down1_2)
        self.Leaky1_2 = self.Leaky(batch1_1)
        down1_3 = self.down1_3(self.Leaky1_2)
        batch1_2 = self.batch1_2(down1_3)
        self.Leaky1_3 = self.Leaky(batch1_2)
        down1_4 = self.down1_4(self.Leaky1_3)
        batch1_3 = self.batch1_3(down1_4)
        self.Leaky1_4 = self.Leaky(batch1_3)
        down1_5 = self.down1_5(self.Leaky1_4)
        batch1_4 = self.batch1_4(down1_5)
        self.out1 = self.Leaky(batch1_4)
        return self.out1

    def Encoder2(self, input2):
        """Encode the 3-channel clothing image; cache skips in ``self.Leaky2_*``."""
        down2_1 = self.down2_1(input2)
        self.Leaky2_1 = self.Leaky(down2_1)
        down2_2 = self.down2_2(self.Leaky2_1)
        batch2_1 = self.batch2_1(down2_2)
        self.Leaky2_2 = self.Leaky(batch2_1)
        down2_3 = self.down2_3(self.Leaky2_2)
        batch2_2 = self.batch2_2(down2_3)
        self.Leaky2_3 = self.Leaky(batch2_2)
        down2_4 = self.down2_4(self.Leaky2_3)
        batch2_3 = self.batch2_3(down2_4)
        self.Leaky2_4 = self.Leaky(batch2_3)
        down2_5 = self.down2_5(self.Leaky2_4)
        batch2_4 = self.batch2_4(down2_5)
        self.out2 = self.Leaky(batch2_4)
        return self.out2

    def Decoder(self):
        """Decode the cached encoder activations into a (B, 3, H, W) image in (0, 1)."""
        Concatenate1 = torch.cat([self.out1, self.out2], dim=1)
        up3_1 = self.up3_1(Concatenate1)
        batch3_1 = self.batch3_1(up3_1)
        relu3_1 = self.relu(batch3_1)
        Concatenate2 = torch.cat([relu3_1, self.Leaky1_4, self.Leaky2_4], dim=1)
        up3_2 = self.up3_2(Concatenate2)
        batch3_2 = self.batch3_2(up3_2)
        relu3_2 = self.relu(batch3_2)
        Concatenate3 = torch.cat([relu3_2, self.Leaky1_3, self.Leaky2_3], dim=1)
        up3_3 = self.up3_3(Concatenate3)
        batch3_3 = self.batch3_3(up3_3)
        relu3_3 = self.relu(batch3_3)
        Concatenate4 = torch.cat([relu3_3, self.Leaky1_2, self.Leaky2_2], dim=1)
        up3_4 = self.up3_4(Concatenate4)
        batch3_4 = self.batch3_4(up3_4)
        relu3_4 = self.relu(batch3_4)
        Concatenate5 = torch.cat([relu3_4, self.Leaky1_1, self.Leaky2_1], dim=1)
        up3_5 = self.up3_5(Concatenate5)
        batch3_5 = self.batch3_5(up3_5)
        relu3_5 = self.relu(batch3_5)
        down3_1 = self.down3_1(relu3_5)
        out3 = self.Sigmoid(down3_1)
        return out3

    def forward(self, input1, input2):
        """Run Encoder1(input1), Encoder2(input2), then the decoder; return its image."""
        self.Encoder1(input1)
        self.Encoder2(input2)
        return self.Decoder()

##Discriminator

In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator scoring (person image, clothing image) pairs.

    The two images are concatenated into a 6-channel input and passed through:
    - five 2x2 / stride-2 convolutions (LeakyReLU, with BatchNorm on the middle three);
    - a final (4, 3) convolution and a sigmoid.
    For 128x96 inputs the (4, 3) kernel exactly covers the 4x3 map left after
    five halvings, so the output has shape (B, 1, 1, 1).

    NOTE(review): the output of ``down5`` goes straight into ``lastdown`` with
    no BatchNorm or activation; confirm this is intended.

    Attribute names and registration order are part of the checkpoint format
    (``state_dict`` keys, optimizer parameter order); do not change them.
    """

    def __init__(self, ngpu):
        # Zero-argument super(): the notebook later rebinds the global name
        # ``Discriminator`` to an instance, which breaks super(Discriminator, self).
        super().__init__()
        self.ngpu = ngpu  # stored only; not used by the layers
        self.down1 = nn.Conv2d(6, 64, 2, 2, 0, bias=False)
        self.Leaky = nn.LeakyReLU(0.2, inplace=True)
        self.down2 = nn.Conv2d(64, 128, 2, 2, 0, bias=False)
        self.batch1 = nn.BatchNorm2d(128)
        self.down3 = nn.Conv2d(128, 256, 2, 2, 0, bias=False)
        self.batch2 = nn.BatchNorm2d(256)
        self.down4 = nn.Conv2d(256, 512, 2, 2, 0, bias=False)
        self.batch3 = nn.BatchNorm2d(512)
        self.down5 = nn.Conv2d(512, 1024, 2, 2, 0, bias=False)
        self.lastdown = nn.Conv2d(1024, 1, (4, 3), 1, 0, bias=False)
        self.sigmoid = nn.Sigmoid()

    def Discriminator(self, input1, input2):
        """Return the probability that (input1, input2) is a real, matching pair.

        Args:
            input1: Person images, (B, C1, H, W).
            input2: Clothing images, (B, C2, H, W), with C1 + C2 == 6. Callers
                pass two RGB images.

        Returns:
            Sigmoid scores, (B, 1, 1, 1) for 128x96 inputs.
        """
        pair = torch.cat([input1, input2], dim=1)
        down1 = self.down1(pair)
        Leaky1 = self.Leaky(down1)
        down2 = self.down2(Leaky1)
        batch1 = self.batch1(down2)
        Leaky2 = self.Leaky(batch1)
        down3 = self.down3(Leaky2)
        batch2 = self.batch2(down3)
        Leaky3 = self.Leaky(batch2)
        down4 = self.down4(Leaky3)
        batch3 = self.batch3(down4)
        Leaky4 = self.Leaky(batch3)
        down5 = self.down5(Leaky4)
        lastdown = self.lastdown(down5)
        return self.sigmoid(lastdown)

    def forward(self, input1, input2):
        """Same as ``Discriminator(input1, input2)``; enables ``model(a, b)`` and hooks."""
        return self.Discriminator(input1, input2)
    


##知覚損失

In [None]:
class Vgg19Loss(nn.Module):
    """Perceptual loss: distance between VGG-19 feature maps of two image batches.

    Five feature taps are read from the VGG ``features`` stack, at the outputs
    of its first 3, 8, 13, 22 and 31 layers. For torchvision's VGG-19 these are
    conv1_2, conv2_2, conv3_2, conv4_2 and conv5_2, before their ReLU.

    For each tap the term is ``sqrt(mean((fx - fy) ** 2) * C * H * W)``: the
    square root of the per-sample squared L2 distance, averaged over the batch.
    The five terms are summed.

    The prefixes are kept as separate ``ModuleList`` attributes
    (``features1`` ... ``features5``) because those names are the
    ``state_dict`` keys stored in checkpoints; they share the same layer objects.

    NOTE(review): images are fed in [0, 1] without the ImageNet mean/std
    normalisation torchvision's pretrained VGG expects; confirm this is intended.
    """

    def __init__(self, features=None):
        """Build the loss.

        Args:
            features: Sequence of layers to tap (31+ layers for all five taps to
                differ). Defaults to ``vgg19.features``, the module-level
                pretrained VGG-19.
        """
        super(Vgg19Loss, self).__init__()
        layers = list(vgg19.features if features is None else features)
        self.features1 = nn.ModuleList(layers[:3]).eval()
        self.features2 = nn.ModuleList(layers[:8]).eval()
        self.features3 = nn.ModuleList(layers[:13]).eval()
        self.features4 = nn.ModuleList(layers[:22]).eval()
        self.features5 = nn.ModuleList(layers[:31]).eval()
        # The loss network is fixed. Gradients still flow back to the inputs,
        # but the VGG weights (never given to an optimizer) no longer
        # accumulate .grad buffers on every backward pass. This also freezes
        # the shared layer objects of ``features`` / the global vgg19.
        self.requires_grad_(False)

    def forward(self, x, y):
        """Return the summed perceptual loss between same-shaped batches ``x`` and ``y``.

        The layer stack is run once per input, reading each tap on the way,
        instead of re-running every prefix from scratch.
        """
        total = 0
        applied = 0  # number of leading layers already applied to x and y
        for stage in (self.features1, self.features2, self.features3,
                      self.features4, self.features5):
            # Each stage is a prefix of the next, so only its new layers run.
            for layer in list(stage)[applied:]:
                x = layer(x)
                y = layer(y)
            applied = len(stage)
            # Use a fresh difference tensor. The next layer may be an in-place
            # ReLU that overwrites x, and autograd rejects that if x itself was
            # saved for backward.
            diff = x - y
            # mean over (B, C, H, W) times C*H*W == per-sample squared L2 / B.
            total = total + torch.sqrt(diff.pow(2).mean() * x[0].numel())
        return total
        

#Unet(Encoder-Decoder)とDiscriminatorの初期化

In [None]:
# Build both networks on the GPU and run init_weights over every submodule.
# nn.Module.apply returns the module itself, so the calls chain.
# NOTE(review): the second line rebinds the name ``Discriminator`` from the
# class to this instance; later cells call ``Discriminator.Discriminator(...)``
# on the instance.
Unet = EncoderDecoder(ngpu).to(device).apply(init_weights)
Discriminator = Discriminator(ngpu).to(device).apply(init_weights)


#損失関数&Adam最適化の定義

In [None]:
# Loss terms:
#   Ladv - adversarial binary cross-entropy on the discriminator's sigmoid output
#   L1   - pixel-wise L1 reconstruction loss
#   Lper - VGG-19 perceptual loss
Ladv, L1 = nn.BCELoss(), nn.L1Loss()
Lper = Vgg19Loss().to(device)

# Values used to fill the BCE target tensor.
real_label, fake_label = 1, 0

# Both networks use Adam with the same hyper-parameters.
optimizerUnet = optim.Adam(Unet.parameters(), lr=lr, betas=(b1, b2))
optimizerDiscriminator = optim.Adam(Discriminator.parameters(), lr=lr, betas=(b1, b2))

#学習

In [None]:
# Per-iteration loss histories and periodic snapshots of generated images.
img_list = []
Unet_losses = []
Generator_losses = []
Discriminator_losses = []
iters = 0  # NOTE(review): never updated below
# Epoch numbering offset.
# NOTE(review): resetting bepoch and the lists above discards what the
# checkpoint-loading cells restore, so resuming from this cell starts over.
bepoch = 0
print("Starting Training Loop...")
start = time.time()
for epoch in range(bepoch, bepoch + num_epoch):
    print(epoch)
    t1 = time.time()
    # The loaders are consumed in lock-step. zip() stops cleanly at the
    # shortest one instead of next() raising StopIteration mid-epoch.
    batches = zip(p_dataloader, c_dataloader, shuffle_c_dataloader, seg_dataloader)
    for i, (p_real_batch, c_real_batch, c_shuffle_real_batch, seg_real_batch) in enumerate(batches):
        # 0. Batch tensors (ImageFolder yields (images, class_ids); keep the images).
        p_b = p_real_batch[0].to(device)                # real person images (target)
        c_t = c_real_batch[0].to(device)                # matching clothes
        c_shuffle = c_shuffle_real_batch[0].to(device)  # randomly paired clothes
        seg_b = seg_real_batch[0].to(device)            # segmentation maps

        # 1. Binary clothing mask, shape (1, B, 1, H, W).
        c_m = onechanel_trans(seg_b).to(device)

        # 2. Paint the clothing region of the person white. overlay() writes
        #    into its second argument in place, so it gets a copy and p_b
        #    stays the untouched target.
        p_m = overlay(c_m[0], p_b.clone()).to(device)

        # 3. U-Net: (masked person + mask) and the garment -> generated person.
        Unet.Encoder1(torch.cat([p_m, c_m[0]], dim=1))
        Unet.Encoder2(c_t)
        p_a = Unet.Decoder().to(device)

        # 4. Discriminator step. The real pair is labelled real; the generated
        #    person and a real person with the wrong garment are both fake.
        Discriminator.zero_grad()
        b_size = p_b.size(0)
        label = torch.full((b_size,), fill_value=real_label, dtype=torch.float32, device=device)
        output = Discriminator.Discriminator(p_b, c_t).view(-1)
        Ladv_realD = Ladv(output, label)
        Ladv_realD.backward()
        D_x = output.mean().item()

        label.fill_(fake_label)
        output = Discriminator.Discriminator(p_a.detach(), c_t).view(-1)
        Ladv_fakeD1 = Ladv(output, label)
        Ladv_fakeD1.backward()
        D_G_z1 = output.mean().item()

        output = Discriminator.Discriminator(p_b, c_shuffle).view(-1)
        Ladv_fakeD2 = Ladv(output, label)
        Ladv_fakeD2.backward()
        D_mismatch = output.mean().item()
        errD = Ladv_realD + Ladv_fakeD1 + Ladv_fakeD2  # logging only; grads already accumulated
        optimizerDiscriminator.step()

        # 5. U-Net step: fool the discriminator plus L1 and perceptual losses
        #    against the real person. L_G.backward() also leaves gradients in
        #    the discriminator; Discriminator.zero_grad() clears them next iteration.
        Unet.zero_grad()
        label.fill_(real_label)
        output = Discriminator.Discriminator(p_a, c_t).view(-1)
        Ladv_fakeG = Ladv(output, label)
        L1_fakeG = L1(p_a, p_b)
        Lper_fakeG = Lper(p_a, p_b)
        L_G = Ladv_fakeG + L1_fakeG + Lper_fakeG
        L_G.backward()
        D_G_z2 = output.mean().item()
        optimizerUnet.step()

        # 6. Record losses.
        Unet_losses.append(L_G.item())
        Generator_losses.append(Ladv_fakeG.item())
        Discriminator_losses.append(errD.item())

        if i % 15 == 0:
            img_list.append(vutils.make_grid(p_a.detach(), padding=2, normalize=True))
            print("[epoch {}][iter {}] Loss_D: {:.4f} Loss_Unet: {:.4f} D(x): {:.4f} "
                  "D(G(z)): {:.4f} -> {:.4f} D(mismatch): {:.4f}".format(
                      epoch, i, errD.item(), L_G.item(), D_x, D_G_z1, D_G_z2, D_mismatch))
    print(epoch)
    t2 = time.time()
    print(t2 - t1)
end = time.time()
print(end - start)

#テスト1


##テストデータのロード

In [None]:
# First test set: folders of person, clothing and segmentation images.
test_dataroot1="test_persons"
test_dataroot2="test_cloths"
test_dataroot3="test_segmentations"

# Same preprocessing as training: resize to img_size, then ToTensor ([0, 1] floats).
test_p_dataset, test_c_dataset, test_seg_dataset = (
    dset.ImageFolder(
        root=root,
        transform=transforms.Compose([transforms.Resize(img_size), transforms.ToTensor()]),
    )
    for root in (test_dataroot1, test_dataroot2, test_dataroot3)
)

# The test cell pairs the three folders by position, batch for batch. Fail
# here, rather than with a StopIteration midway, if cloths or segmentations run out first.
if len(test_c_dataset) < len(test_p_dataset) or len(test_seg_dataset) < len(test_p_dataset):
    raise ValueError(
        "test dataset sizes do not line up: persons={}, cloths={}, segmentations={}".format(
            len(test_p_dataset), len(test_c_dataset), len(test_seg_dataset)
        )
    )

test_p_dataloader = torch.utils.data.DataLoader(test_p_dataset, batch_size=64, shuffle=False, num_workers=num_thread)
test_c_dataloader = torch.utils.data.DataLoader(test_c_dataset, batch_size=64, shuffle=False, num_workers=num_thread)
test_seg_dataloader = torch.utils.data.DataLoader(test_seg_dataset, batch_size=64, shuffle=False, num_workers=num_thread)
device = torch.device("cuda:0")

##テスト実行

In [None]:
# Run the U-Net over the first test set. Only the LAST batch's output is kept
# in test_p_a; that is what the next cell displays.
# torch.no_grad(): inference only, so no autograd graph is recorded.
# NOTE(review): the network stays in training mode, so BatchNorm normalises
# with each test batch's statistics and updates its running statistics, which
# avoid_Unet() later saves. This may be intentional; call Unet.eval() first
# (and Unet.train() afterwards) if running statistics should be used.
with torch.no_grad():
    for test_p_real_batch, test_c_real_batch, test_seg_real_batch in zip(
        test_p_dataloader, test_c_dataloader, test_seg_dataloader
    ):
        test_p_b = test_p_real_batch[0].to(device)
        test_c_t = test_c_real_batch[0].to(device)
        test_seg_b = test_seg_real_batch[0].to(device)

        # 1. Clothing mask from the segmentation map, shape (1, B, 1, H, W).
        test_c_m = onechanel_trans(test_seg_b).to(device)

        # 2. Paint the clothing region of the person image white (in place).
        test_p_m = overlay(test_c_m[0], test_p_b).to(device)

        # 3. U-Net: (masked person + mask) and the garment -> generated person.
        Unet.Encoder1(torch.cat([test_p_m, test_c_m[0]], dim=1))
        Unet.Encoder2(test_c_t)
        test_p_a = Unet.Decoder().to(device)

##生成画像を表示

In [None]:
image_num=64  # how many generated images to tile into the grid
# Tile the first image_num results into one grid; CHW -> HWC for imshow.
plt.imshow(
    np.transpose(
        vutils.make_grid(test_p_a[:image_num].detach(), padding=2, normalize=True).cpu(),
        (1, 2, 0),
    )
)

#テスト2


##テストデータのロード


In [None]:
# Second test set (folders good_*_s).
test_dataroot1="good_persons_s"
test_dataroot2="good_cloths_s"
test_dataroot3="good_segmentations_s"

# Same preprocessing as training: resize to img_size, then ToTensor.
test_p_dataset, test_c_dataset, test_seg_dataset = (
    dset.ImageFolder(
        root=root,
        transform=transforms.Compose([transforms.Resize(img_size), transforms.ToTensor(),]),
    )
    for root in (test_dataroot1, test_dataroot2, test_dataroot3)
)

# One person (and its segmentation map) per step; garments in batches of 10.
test_p_dataloader, test_c_dataloader, test_seg_dataloader = (
    torch.utils.data.DataLoader(dataset, batch_size=size, shuffle=False, num_workers=num_thread)
    for dataset, size in ((test_p_dataset, 1), (test_c_dataset, 10), (test_seg_dataset, 1))
)
device = torch.device("cuda:0")

##テスト実行

In [None]:
# Try the first batch of garments on each of the first stop + 1 persons
# (indices 0..stop, as before), collecting every result in test_p_a.
# NOTE(review): like the first test cell, this runs with the network in
# training mode (BatchNorm uses batch statistics and updates running stats).
stop = 10
test_p_a = []
with torch.no_grad():
    # The same garment batch (up to batch_size of test_c_dataloader) is used for every person.
    test_c_real_batch = next(iter(test_c_dataloader))
    test_c_t = test_c_real_batch[0].to(device)
    n_cloths = test_c_t.size(0)

    outputs = []
    for i, (test_p_real_batch, test_seg_real_batch) in enumerate(
        zip(test_p_dataloader, test_seg_dataloader)
    ):
        # Tile the person and its segmentation map once per garment, so every
        # garment gets its own copy (previously hard-wired to exactly 10 copies).
        test_p_b = test_p_real_batch[0].repeat(n_cloths, 1, 1, 1).to(device)
        test_seg_b = test_seg_real_batch[0].repeat(n_cloths, 1, 1, 1).to(device)

        # 1. Clothing mask from the segmentation map, shape (1, N, 1, H, W).
        test_c_m = onechanel_trans(test_seg_b).to(device)

        # 2. Paint the clothing region of the person images white (in place).
        test_p_m = overlay(test_c_m[0], test_p_b).to(device)

        # 3. U-Net: (masked person + mask) and each garment -> generated person.
        Unet.Encoder1(torch.cat([test_p_m, test_c_m[0]], dim=1))
        Unet.Encoder2(test_c_t)
        outputs.append(Unet.Decoder().to(device))

        if i == stop:
            break

    # Concatenate once instead of growing the tensor every iteration.
    if outputs:
        test_p_a = torch.cat(outputs, dim=0)

##生成画像を表示

In [None]:
image_num=64  # NOTE(review): not used below; the grid shows the first 13 results
# First 13 results. With 10 garments per person these are all of person 0's
# results plus the first 3 of person 1. Converted CHW -> HWC for imshow.
plt.imshow(
    np.transpose(
        vutils.make_grid(test_p_a[:13].detach(), padding=2, normalize=True).cpu(),
        (1, 2, 0),
    )
)

#モデルの退避

In [None]:
def avoid_Unet(path='/content/drive/My Drive/models/Unet_6_5242_17_advloss.pth'):
    """Save a U-Net checkpoint: weights, optimizer state, loss histories, snapshots.

    Reads these notebook globals: ``epoch``, ``img_list``, ``Unet_losses``,
    ``Generator_losses``, ``Unet``, ``optimizerUnet``, ``Ladv``, ``L1``, ``Lper``.

    Args:
        path: Destination file. It is absolute by default. The former relative
            path ``'drive/My Drive/models/...'`` only resolved from /content,
            but the parameters cell ``%cd``s into the Drive folder. The parent
            directory is created if missing.
    """
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save({
            'epoch': epoch,
            'image': img_list,
            'Unet_losses': Unet_losses,
            'Generator': Generator_losses,
            'model_state_dict': Unet.state_dict(),
            'optimizer_state_dict': optimizerUnet.state_dict(),
            'criterion1': Ladv.state_dict(),
            'criterion2': L1.state_dict(),
            'criterion3': Lper.state_dict(),
            }, path)

In [None]:
def avoid_D(path='/content/drive/My Drive/models/Discriminator_6_5242_17_advloss.pth'):
    """Save a discriminator checkpoint: weights, optimizer state and loss history.

    Reads these notebook globals: ``Discriminator_losses``, ``Discriminator``
    (the instance), ``optimizerDiscriminator``.

    Args:
        path: Destination file. It is absolute by default because the former
            relative path depended on the current directory. The parent
            directory is created if missing.
    """
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save({
            # Key spelling is kept as-is: the loading cell reads this exact key.
            'Disciriminator_losses': Discriminator_losses,
            'model_state_dict': Discriminator.state_dict(),
            'optimizer_state_dict': optimizerDiscriminator.state_dict(),
            }, path)
    

In [None]:
# Write both checkpoints to the Drive "models" folder: U-Net first, then the discriminator.
avoid_Unet()
avoid_D()

#モデルのロード

In [None]:
# Change into the mounted Drive folder (path is relative to the current directory).
%cd drive/My\ Drive

In [None]:
# Go up one directory level.
%cd ..

In [None]:
# Restore a U-Net checkpoint written by avoid_Unet(). The path is absolute, so
# it does not depend on which %cd cells ran before.
PATHG = '/content/drive/My Drive/models/Unet_2.7_5242_16.5_advloss.pth'
# Fresh optimizer and loss objects; their states are overwritten below.
optimizerUnet = optim.Adam(Unet.parameters(), lr=lr, betas=(b1, b2))
Ladv = nn.BCELoss()
L1 = nn.L1Loss()
Lper = Vgg19Loss().to(device)

# map_location puts every stored tensor on `device`, regardless of which GPU
# index (or CPU) saved the checkpoint.
checkpoint = torch.load(PATHG, map_location=device)
Unet.load_state_dict(checkpoint['model_state_dict'])
optimizerUnet.load_state_dict(checkpoint['optimizer_state_dict'])
Ladv.load_state_dict(checkpoint['criterion1'])
L1.load_state_dict(checkpoint['criterion2'])
Lper.load_state_dict(checkpoint['criterion3'])
# Training history, for plotting or resuming.
# NOTE(review): the training cell re-initialises bepoch and these lists when run.
bepoch = checkpoint['epoch']
img_list = checkpoint['image']
Unet_losses = checkpoint['Unet_losses']
Generator_losses = checkpoint['Generator']

In [None]:
# Restore a discriminator checkpoint written by avoid_D(). The path is absolute,
# so it does not depend on which %cd cells ran before.
PATHD = '/content/drive/My Drive/models/Discriminator_2.7_5242_16.5_advloss.pth'
# `Discriminator` is the instance here: the model-construction cell rebound the
# name from the class to the instance.
optimizerDiscriminator = optim.Adam(Discriminator.parameters(), lr=lr, betas=(b1, b2))

checkpoint = torch.load(PATHD, map_location=device)
Discriminator.load_state_dict(checkpoint['model_state_dict'])
optimizerDiscriminator.load_state_dict(checkpoint['optimizer_state_dict'])
# The key's spelling matches what avoid_D() writes.
Discriminator_losses = checkpoint['Disciriminator_losses']

#損失の表示

In [None]:
# Per-iteration adversarial losses of the generator (U-Net) and the
# discriminator on one chart.
plt.figure(figsize=(10, 5))
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.title("Generator and Discriminator Loss During Training")
plt.plot(Generator_losses, label="G")
plt.plot(Discriminator_losses, label="D")
plt.legend()
plt.show()

In [None]:
# Total U-Net loss per iteration (adversarial + L1 + perceptual).
plt.figure(figsize=(10, 5))
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.title("Unet Loss During Training")
plt.plot(Unet_losses, label="Unet")
plt.legend()
plt.show()