In [0]:
import torch
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import (Dataset,DataLoader,TensorDataset)
from torchvision import models
from torch import nn,optim
import tqdm
from torchvision import models
from IPython.display import Image,display_jpeg
from torchvision.utils import save_image

Google Colabで動かすときは、次のセルを実行してGoogleドライブをマウントする。

In [0]:
# Colab only: mount Google Drive at /content/gdrive so files under "My Drive"
# can be read (the training loader below reads from "gdrive/My Drive/...").
from google.colab import drive
drive.mount('/content/gdrive')
# Quick sanity check that the mount worked: list the "Colab Notebooks" folder.
!ls ./gdrive/'My Drive'/"Colab Notebooks"

In [0]:
!mkdir styletrain
!mkdir styletrain/jpg
!mkdir style
!mkdir style/jpg
!mkdir sample
!mkdir sample/jpg　

スタイル画像を style/jpg に1枚だけ入れておく。また、学習の途中経過を確認するためのサンプル画像を sample/jpg に2枚入れておく。

In [0]:
# Shared preprocessing: shorter side resized to 224 px, center 224x224 crop,
# converted to a float tensor.
preprocess=transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
batch_size=1

# Style image: ImageFolder needs a class subfolder, hence style/jpg.
style_data=ImageFolder("style/",transform=preprocess)
style_loader=DataLoader(style_data,batch_size=batch_size,shuffle=True)

# Sample image(s) used to monitor training progress.
sample_data=ImageFolder("sample/",transform=preprocess)
sample_loader=DataLoader(sample_data,batch_size=batch_size,shuffle=True)


次のセルの1行目のパスには、学習用データが入ったフォルダを指定する（ImageFolder形式のため、画像はそのフォルダ内のサブフォルダに置く）。

In [0]:
# Content (training) images. ImageFolder expects <root>/<class>/<image>;
# change the path to your own data folder on the mounted Drive.
train_transform=transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
train_data=ImageFolder("gdrive/My Drive/deeplearnning/data",transform=train_transform)
batch_size=4
train_loader=DataLoader(train_data,batch_size=batch_size,shuffle=True)

In [0]:
class ConvNormRelu(nn.Module):
  """Conv2d -> BatchNorm2d -> ReLU.

  Note: despite its name, `batch_size` is the convolution *kernel size*
  (it is passed as Conv2d's third positional argument); the parameter name
  is kept so existing callers keep working.
  The attribute names `conv2d` / `BatchNorm2d` are the state_dict keys, so
  they must not change or saved weights will no longer load.
  """
  def __init__(self,input_feature,output_feature,batch_size,stride,padding):
    super().__init__()
    kernel_size=batch_size
    self.conv2d=nn.Conv2d(input_feature,output_feature,kernel_size,stride=stride,padding=padding)
    self.relu=nn.ReLU()
    self.BatchNorm2d=nn.BatchNorm2d(output_feature)
  def forward(self,x):
    # conv, then normalize, then rectify
    return self.relu(self.BatchNorm2d(self.conv2d(x)))
  
class TransConvNormRelu(nn.Module):
  """ConvTranspose2d -> BatchNorm2d -> ReLU (upsampling stage).

  Note: `batch_size` is actually the transposed-convolution kernel size;
  the name is kept for compatibility with existing callers.
  Attribute names `convtrans2d` / `BatchNorm2d` are state_dict keys.
  """
  def __init__(self,input_feature,output_feature,batch_size,stride,padding,output_padding):
    super().__init__()
    kernel_size=batch_size
    self.convtrans2d=nn.ConvTranspose2d(input_feature,output_feature,kernel_size,
                                        stride=stride,padding=padding,output_padding=output_padding)
    self.relu=nn.ReLU()
    self.BatchNorm2d=nn.BatchNorm2d(output_feature)
  def forward(self,x):
    upsampled=self.convtrans2d(x)
    return self.relu(self.BatchNorm2d(upsampled))

class residual(nn.Module):
  """Residual block on 128-channel maps: x + BN(conv(ReLU(BN(conv(x))))).

  NOTE(review): the same `conv2d` and `BatchNorm2d` modules are applied twice,
  so both halves share weights and BatchNorm running statistics. A typical
  residual block uses two independent conv/BN pairs; changing that would
  alter the state_dict (and break existing checkpoints), so it is kept as-is.
  """
  def __init__(self):
    super().__init__()
    self.conv2d=nn.Conv2d(128,128,3,1,1)
    self.relu=nn.ReLU()
    self.BatchNorm2d=nn.BatchNorm2d(128)
  def _conv_bn(self,t):
    # one application of the shared conv + batch norm
    return self.BatchNorm2d(self.conv2d(t))
  def forward(self,x):
    hidden=self.relu(self._conv_bn(x))
    return self._conv_bn(hidden)+x
  
# Image-transformation network: 3 conv stages, 5 residual blocks, 2 upsampling
# stages and a final 3-channel output. The Sigmoid keeps outputs in [0, 1], the
# same range as transforms.ToTensor(). Layer positions define the state_dict
# keys ("0.conv2d.weight", ...), so the generation cell must build the same list.
layers=[
    ConvNormRelu(3,32,9,1,4),     # 9x9 conv, stride 1
    ConvNormRelu(32,64,3,2,1),    # stride 2: halves H and W
    ConvNormRelu(64,128,3,2,1),   # stride 2: halves H and W
]
layers+=[residual() for _ in range(5)]
layers+=[
    TransConvNormRelu(128,64,3,2,1,1),  # stride 2: doubles H and W
    TransConvNormRelu(64,32,3,2,1,1),   # stride 2: doubles H and W
    nn.ConvTranspose2d(32, 3, 9, 1, 4, bias=False),
    nn.BatchNorm2d(3),
    nn.Sigmoid(),
]
net=nn.Sequential(*layers)


# Helper that returns the output of an intermediate layer.
def midout(x,midnum,net):
  """Feed `x` through layers 0..midnum (inclusive) of the sequential `net`.

  Returns the output of layer `midnum`. In this notebook `net` is
  vggnet.features, so this reads intermediate VGG16 feature maps.
  """
  for layer in net[:midnum+1]:
    x=layer(x)
  return x
# This notebook reads layers 3, 8, 15 and 22.


# Pretrained VGG16 used as a fixed feature extractor for the losses. Its
# parameters are frozen; the optimizer in train_net only updates `net`.
# NOTE(review): images reach VGG as plain ToTensor output ([0, 1]) without
# ImageNet mean/std normalization — confirm this is intended.
vggnet=models.vgg16(pretrained=True)
for param in vggnet.parameters():
  param.requires_grad_(False)

  
def convarray(x):
  """Gram matrix of a feature map, pooled over the batch.

  Presumably x is (B, C, H, W) — dims 0, 2, 3 are used as the divisor. Each
  channel is flattened over batch and spatial positions, giving a (C, C)
  matrix: sum_b F_b F_b^T / (B*H*W).
  """
  dims=x.size()
  feats=x.transpose(0,1).contiguous().view(dims[1],-1)
  return feats.mm(feats.t())/(dims[0]*dims[2]*dims[3])

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.torch/models/vgg16-397923af.pth
100%|██████████| 553433881/553433881 [00:20<00:00, 27149817.42it/s]


In [0]:
# Style-loss targets, computed once from the style image (only one is used;
# with several, shuffle=True makes the choice random).
# style_array layout:
#   [0]    layer-15 VGG features of the style image (squeezed)
#   [1..4] Gram matrices of VGG layers 3, 8, 15, 22
with torch.no_grad():
  style_img,_=next(iter(style_loader))
  style_array=[midout(style_img,15,vggnet.features).squeeze()]
  for layer_idx in (3,8,15,22):
    style_array.append(convarray(midout(style_img,layer_idx,vggnet.features)))

# Monitoring images that train_net renders periodically as a progress check.
airi=[img.to("cuda:0") for img,_ in sample_loader]
if len(airi)==1:
  # Only one sample image was provided: reuse it so airi[1] also exists.
  airi.append(airi[0])


In [0]:
def _vgg_features(x,net2,layer_ids):
    """Run x once through sequential net2; return {layer index: output} for layer_ids.

    Same values as calling midout(x, i, net2) for each i, but with a single
    forward pass. Outputs are cloned so a later in-place layer cannot
    overwrite a captured feature map.
    """
    feats={}
    last=max(layer_ids)
    for idx,layer in enumerate(net2):
        x=layer(x)
        if idx in layer_ids:
            feats[idx]=x.clone()
        if idx==last:
            break
    return feats

def train_net(net, net2,train_loader,target ,optimizer_cls=optim.Adam,loss_fn=nn.MSELoss(),n_iter=10, device='cpu', monitor_imgs=None):
    """Train the transformation network `net` with content + style losses.

    Args:
      net: transformation network being trained.
      net2: frozen sequential feature extractor (here vggnet.features).
      train_loader: yields (images, labels) batches of content images.
      target: style_array; entries 1..4 are the style Gram matrices for
        layers 3, 8, 15, 22 (entry 0 is not used).
      optimizer_cls: optimizer class, built over net.parameters().
      loss_fn: loss used for the content and style terms.
      n_iter: number of epochs (full passes over train_loader).
      device: device for batches, style targets and loss weights.
      monitor_imgs: images rendered every 40 iterations; defaults to the
        global `airi` list.

    Side effects: every 40 iterations saves "net_XXXXX.prm" and a preview
    "XXXXXX.jpg", displays the preview and prints the current losses.

    Returns:
      list with the mean training loss of each epoch.
    """
    content_layer=15
    style_layers=(3,8,15,22)
    if monitor_imgs is None:
        monitor_imgs=airi  # global filled by the sample-image cell
    optimizer=optimizer_cls(net.parameters())
    # Style targets are constants: detach them and move them to `device`
    # (previously hard-coded to "cuda:0", which broke device='cpu').
    style_targets=[t.detach().to(device) for t in target[1:5]]
    # Loss weights: [content, style layer 3, 8, 15, 22]. They control how
    # strongly the style appears: the smaller the first (content) weight is
    # relative to the others, the stronger the style.
    weight=torch.tensor([1.0,40.0,40.0,40.0,40.0],device=device)
    train_losses=[]
    for epoch in range(n_iter):
        running_loss=0.0
        net.train()
        for i,(xx,_) in tqdm.tqdm(enumerate(train_loader),total=len(train_loader)):
            xx=xx.to(device)
            xx2=net(xx)
            # One VGG pass for the generated batch covers all loss layers.
            pred_feats=_vgg_features(xx2,net2,style_layers)
            with torch.no_grad():
                target_contents=_vgg_features(xx,net2,(content_layer,))[content_layer]
            y_pred0=pred_feats[content_layer]
            loss0=loss_fn(y_pred0,target_contents)
            loss0=loss0/y_pred0.size()[0]
            loss1,loss2,loss3,loss4=[
                loss_fn(convarray(pred_feats[k]),t)
                for k,t in zip(style_layers,style_targets)]

            loss=weight[0]*loss0+weight[1]*loss1+weight[2]*loss2+weight[3]*loss3+weight[4]*loss4
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss+=loss.item()
            if i%40==0:
                torch.save(
                    net.state_dict(),
                    "net_{:05d}.prm".format(epoch*10000+i),
                    pickle_protocol=4)
                # Preview only: no autograd graph is needed.
                # NOTE(review): net stays in train mode, so BatchNorm uses the
                # preview images' own statistics (and updates running stats).
                with torch.no_grad():
                    previews=[net(img.to(device)) for img in monitor_imgs]
                gazou=torch.cat(previews,3)
                img_name="{:06d}.jpg".format((epoch+10)*10000+i)
                save_image(gazou,img_name)
                display_jpeg(Image(img_name))
                print(epoch,i,loss.item())
                print(loss0.item(),loss1.item(),loss2.item(),loss3.item(),loss4.item())
        train_losses.append(running_loss/len(train_loader))
        print(epoch,train_losses[-1],flush=True)
    return train_losses
        

In [0]:
from IPython.display import Image,display_jpeg
from torchvision.utils import save_image

# Put both networks on the GPU and train.
# n_iter is the number of epochs (full passes over train_loader); adjust as needed.
device="cuda:0"
net=net.to(device)
vggnet=vggnet.to(device)
train_net(net,vggnet.features,train_loader,style_array,n_iter=10,device=device)


ここから下は生成タスク（学習済みの重みを読み込み、画像をスタイル変換する）。

In [0]:
!mkdir transfer
!mkdir transfer/jpg

変換したい画像を transfer/jpg に入れておく（複数枚可）。



In [0]:
# Images to stylize. Note: this reuses the names sample_data / sample_loader,
# replacing the monitoring loader defined in the training section.
transfer_transform=transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
sample_data=ImageFolder("transfer/",transform=transfer_transform)
batch_size=1
sample_loader=DataLoader(sample_data,batch_size=batch_size,shuffle=True)

In [0]:
class ConvNormRelu(nn.Module):
  """Conv2d -> BatchNorm2d -> ReLU.

  Re-declared for the generation section; it must stay identical to the
  training definition so saved state_dict keys (`conv2d`, `BatchNorm2d`) match.
  Note: `batch_size` is actually the convolution kernel size.
  """
  def __init__(self,input_feature,output_feature,batch_size,stride,padding):
    super().__init__()
    kernel_size=batch_size
    self.conv2d=nn.Conv2d(input_feature,output_feature,kernel_size,stride=stride,padding=padding)
    self.relu=nn.ReLU()
    self.BatchNorm2d=nn.BatchNorm2d(output_feature)
  def forward(self,x):
    return self.relu(self.BatchNorm2d(self.conv2d(x)))
  
class TransConvNormRelu(nn.Module):
  """ConvTranspose2d -> BatchNorm2d -> ReLU (upsampling stage).

  Re-declared for the generation section; keep identical to the training
  definition so state_dict keys (`convtrans2d`, `BatchNorm2d`) match.
  Note: `batch_size` is actually the kernel size.
  """
  def __init__(self,input_feature,output_feature,batch_size,stride,padding,output_padding):
    super().__init__()
    kernel_size=batch_size
    self.convtrans2d=nn.ConvTranspose2d(input_feature,output_feature,kernel_size,
                                        stride=stride,padding=padding,output_padding=output_padding)
    self.relu=nn.ReLU()
    self.BatchNorm2d=nn.BatchNorm2d(output_feature)
  def forward(self,x):
    upsampled=self.convtrans2d(x)
    return self.relu(self.BatchNorm2d(upsampled))

class residual(nn.Module):
  """Residual block on 128-channel maps: x + BN(conv(ReLU(BN(conv(x))))).

  Re-declared for the generation section; keep identical to the training
  definition so checkpoints load.
  NOTE(review): `conv2d` and `BatchNorm2d` are each applied twice (shared
  weights and shared BatchNorm running statistics).
  """
  def __init__(self):
    super().__init__()
    self.conv2d=nn.Conv2d(128,128,3,1,1)
    self.relu=nn.ReLU()
    self.BatchNorm2d=nn.BatchNorm2d(128)
  def _conv_bn(self,t):
    # one application of the shared conv + batch norm
    return self.BatchNorm2d(self.conv2d(t))
  def forward(self,x):
    hidden=self.relu(self._conv_bn(x))
    return self._conv_bn(hidden)+x

# Rebuild the transformation network exactly as in training. The output layer
# must match the training cell, which ends with nn.Sigmoid (pixels in [0, 1]).
# nn.Tanh has no parameters, so the weights would still load, but it maps the
# same pre-activations to [-1, 1] and yields wrong images.
net=nn.Sequential(
    ConvNormRelu(3,32,9,1,4),
    ConvNormRelu(32,64,3,2,1),
    ConvNormRelu(64,128,3,2,1),
    residual(),
    residual(),
    residual(),
    residual(),
    residual(),
    TransConvNormRelu(128,64,3,2,1,1),
    TransConvNormRelu(64,32,3,2,1,1),
    nn.ConvTranspose2d(32, 3, 9, 1, 4, bias=False),
    nn.BatchNorm2d(3),
    nn.Sigmoid()
)
net=net.to("cuda:0")

# Load the latest trained weights (set the checkpoint path as needed).
# map_location places the tensors on the same device as `net`, whatever
# device the checkpoint was saved from.
params=torch.load("burn.prm",map_location="cuda:0")

net.load_state_dict(params)

In [0]:
from PIL import ImageFile
# Let PIL open truncated image files instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Stylize every image in transfer/jpg. Image k is saved as origin{k:03d}.jpg
# (the input) and fire{k:03d}.jpg (the result), and the result is displayed.
# Previously the input was always written to "origin.jpg" (the format string
# had no placeholder), so each input overwrote the last.
# Inference only: torch.no_grad() avoids building an autograd graph.
# NOTE(review): net is left in train mode, so BatchNorm uses per-image batch
# statistics — confirm before switching to net.eval().
with torch.no_grad():
  for roop,(pic,_) in enumerate(sample_loader):
    pic=pic.to("cuda:0")
    save_image(pic,"origin{:03d}.jpg".format(roop))
    generated_imgsub=net(pic)
    save_image(generated_imgsub,"fire{:03d}.jpg".format(roop))
    display_jpeg(Image("fire{:03d}.jpg".format(roop)))