In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import _LRScheduler
import torch.utils.data as data
import torchvision
from torchvision import datasets, models, transforms
from sklearn import decomposition
from sklearn import manifold
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import copy
import random
import time
from PIL import Image

In [2]:
# Load the evaluation split from the CSV next to the notebook.
test = pd.read_csv(filepath_or_buffer='test.csv')

In [3]:
# Load the training split from the CSV next to the notebook.
train = pd.read_csv(filepath_or_buffer='train.csv')

In [4]:
# Map each integer `code` in the training frame to its `char`
# (for a repeated code, the last row wins).
char_dict = train.set_index('code')['char'].to_dict()

In [5]:
# Inverse lookup of `char_dict`: character -> integer code.
char_to_code = {char: code for code, char in char_dict.items()}

In [6]:
# Attach the training-derived integer code to every test row; characters that
# are missing from `char_to_code` come out as NaN.
test['code'] = test.char.map(char_to_code)

In [7]:
# Build the "<font> <code>" label string that MultiTaskDataset later splits on " ".
test['label'] = test['font'].map(str).str.cat(test['code'].map(str), sep=' ')

In [8]:
# Display the test frame to eyeball the derived `code` and `label` columns.
test

Unnamed: 0.1,Unnamed: 0,path,font,char,code,label
0,0,char\一字\一字 楷书 欧阳询.jpg,0,一,0,0 0
1,1,char\一字\一字 篆书 徐三庚.jpg,4,一,0,4 0
2,2,char\一字\一字 草书 孙过庭.jpg,3,一,0,3 0
3,3,char\一字\一字 草书 张旭.jpg,3,一,0,3 0
4,4,char\一字\一字 草书 毛泽东.jpg,3,一,0,3 0
...,...,...,...,...,...,...
13395,13395,char\龟字\龟字 草书 邓文原.jpg,3,龟,2309,3 2309
13396,13396,char\龟字\龟字 行书 米芾.jpg,2,龟,2309,2 2309
13397,13397,char\龟字\龟字 行书 苏轼.jpg,2,龟,2309,2 2309
13398,13398,char\龟字\龟字 行书 赵孟頫.jpg,2,龟,2309,2 2309


In [9]:
class MyModel(nn.Module):
    """ResNet-50 backbone with two heads: font (fc1) and font-conditioned character (fc2).

    The backbone features are concatenated with an embedding of the font class
    and fed to ``fc2``.  While training, the embedded font is the ground-truth
    label with probability ``k / (k + exp(eps / k))`` (``eps`` counts training
    forward passes) and the font predicted by ``fc1`` otherwise.  In eval mode,
    or when no labels are given, the predicted font is always used, so
    evaluation never sees ground-truth labels and never advances ``eps``.

    Args:
        num_classes1: number of font classes (size of ``fc1`` output and of the
            embedding table).
        num_classes2: number of character classes (size of ``fc2`` output).
        k: constant of the sampling schedule above; larger values keep the
            ground-truth branch likely for more steps.
        embed_dim: width of the font embedding.
    """

    def __init__(self, num_classes1, num_classes2, k=10.0, embed_dim=2048):
        super(MyModel, self).__init__()
        # Counter of training forward passes; drives the sampling schedule.
        self.eps = 1
        self.k = float(k)
        self.model_resnet = models.resnet50(pretrained=True)
        num_ftrs = self.model_resnet.fc.in_features
        # Drop the ImageNet classifier so the backbone returns pooled features.
        self.model_resnet.fc = nn.Identity()
        self.fc1 = nn.Linear(num_ftrs, num_classes1)
        # Input is [backbone features | font embedding]; derive the width
        # instead of hard-coding 4096.
        self.fc2 = nn.Linear(num_ftrs + embed_dim, num_classes2)
        self.softmax = nn.Softmax(dim=1)
        # One embedding row per font class (was hard-coded to 5).
        self.embed = nn.Embedding(num_classes1, embed_dim)

    def font_code(self, y):
        """Embed ground-truth font indices ``y``; returns a (len(y), embed_dim) tensor."""
        # A single batched lookup; the result lives on the embedding's device.
        return self.embed(y.reshape(-1))

    def predict_font_code(self, y):
        """Embed the arg-max class of each row of the font logits ``y``."""
        return self.embed(y.argmax(dim=1))

    def _teacher_forcing_prob(self):
        """Probability of feeding the ground-truth font at the current ``eps``."""
        k = float(self.k)
        # torch.exp saturates to inf (probability -> 0) where math.exp would
        # raise OverflowError once eps / k gets large.
        decay = torch.exp(torch.tensor(self.eps / k))
        return (k / (k + decay)).item()

    def forward(self, x, y=None):
        """Return ``(font_logits, char_logits)`` for an image batch ``x``.

        ``y`` holds the ground-truth font indices; it is only consulted in
        training mode and may be omitted for inference.
        """
        x = self.model_resnet(x)
        out1 = self.fc1(x)

        use_truth = False
        if self.training and y is not None:
            use_truth = torch.rand(1).item() < self._teacher_forcing_prob()
            self.eps += 1

        font = self.font_code(y) if use_truth else self.predict_font_code(out1)
        combined = torch.cat((x.view(x.size(0), -1), font.view(font.size(0), -1)), dim=1)
        out2 = self.fc2(combined)
        return out1, out2

In [10]:
# First argument sizes the font head, second the character head.
model = MyModel(num_classes1=5, num_classes2=2310)

In [11]:
def training(model, iterator, optimizer, criterion, device):
    """Run one optimisation pass over ``iterator`` and return averaged metrics.

    Each batch is ``(x, y1, y2)``.  The model is called as ``model(x, y1)`` and
    must return two outputs, scored against ``y1`` and ``y2`` respectively; the
    optimised loss is the sum of both criterion values.

    Metrics are averaged per sample (each batch weighted by its size), so a
    short final batch does not skew the result.  This assumes ``criterion``
    returns a batch mean, as ``nn.CrossEntropyLoss`` does by default.

    Returns:
        ``(loss, loss1, loss2, acc1, acc2)`` as Python floats.

    Raises:
        ValueError: if ``iterator`` yields no samples.
    """
    model.train()

    # Running sums of batch_metric * batch_size, in return order.
    totals = [0.0] * 5
    n_samples = 0

    for x, y1, y2 in iterator:
        x = x.to(device)
        y1 = y1.to(device)
        y2 = y2.to(device)

        optimizer.zero_grad()

        outputs = model(x, y1)

        loss1 = criterion(outputs[0], y1)
        loss2 = criterion(outputs[1], y2)
        loss = loss1 + loss2

        acc1 = calculate_accuracy(outputs[0], y1)
        acc2 = calculate_accuracy(outputs[1], y2)

        loss.backward()
        optimizer.step()

        batch_size = y1.shape[0]
        n_samples += batch_size
        for slot, value in enumerate((loss, loss1, loss2, acc1, acc2)):
            totals[slot] += value.item() * batch_size

    if n_samples == 0:
        raise ValueError("training() received an iterator with no samples")

    return tuple(total / n_samples for total in totals)

In [12]:
def evaluate(model, iterator, criterion, device):
    """Score ``model`` on ``iterator`` without updating it and return averaged metrics.

    Each batch is ``(x, y1, y2)``.  The model is put in eval mode, called as
    ``model(x, y1)`` under ``torch.no_grad()``, and must return two outputs,
    scored against ``y1`` and ``y2`` respectively.  No optimizer is involved,
    so the function depends only on its arguments.

    Metrics are averaged per sample (each batch weighted by its size).  This
    assumes ``criterion`` returns a batch mean, as ``nn.CrossEntropyLoss``
    does by default.

    Returns:
        ``(loss, loss1, loss2, acc1, acc2)`` as Python floats.

    Raises:
        ValueError: if ``iterator`` yields no samples.
    """
    model.eval()

    # Running sums of batch_metric * batch_size, in return order.
    totals = [0.0] * 5
    n_samples = 0

    with torch.no_grad():
        for x, y1, y2 in iterator:
            x = x.to(device)
            y1 = y1.to(device)
            y2 = y2.to(device)

            outputs = model(x, y1)

            loss1 = criterion(outputs[0], y1)
            loss2 = criterion(outputs[1], y2)
            loss = loss1 + loss2

            acc1 = calculate_accuracy(outputs[0], y1)
            acc2 = calculate_accuracy(outputs[1], y2)

            batch_size = y1.shape[0]
            n_samples += batch_size
            for slot, value in enumerate((loss, loss1, loss2, acc1, acc2)):
                totals[slot] += value.item() * batch_size

    if n_samples == 0:
        raise ValueError("evaluate() received an iterator with no samples")

    return tuple(total / n_samples for total in totals)

In [13]:
# Per-channel normalisation applied after ToTensor; MultiTaskDataset.show
# undoes it with the same mean/std values.
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Image pipeline: resize to 256, centre-crop to 224, convert to a tensor,
# then normalise.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        _normalize,
    ]
)

In [14]:
class MultiTaskDataset(data.Dataset):
    """Image dataset yielding ``(image, font, code)`` for the two-head model.

    Args:
        df: object exposing ``path`` (image file paths) and ``label`` (strings
            of the form ``"<font> <code>"``) as iterable attributes, e.g. a
            pandas DataFrame with those columns.
        transform: callable applied to the RGB PIL image.  Defaults to the
            notebook-level ``preprocess`` pipeline, looked up at access time.
    """

    def __init__(self, df, transform=None):
        self.paths = list(df.path)
        self.labels = list(df.label)
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(image, font, code)``; font and code are int64 scalar tensors."""
        # Context manager closes the file handle; convert() returns a fully
        # loaded copy, so the image stays usable after the file is closed.
        with Image.open(self.paths[idx]) as raw:
            img = raw.convert('RGB')
        transform = preprocess if self.transform is None else self.transform
        img = transform(img)

        # Labels are stored as "<font> <code>".
        labels = self.labels[idx].split(" ")
        font = torch.tensor(int(labels[0]), dtype=torch.int64)
        code = torch.tensor(int(labels[1]), dtype=torch.int64)

        return img, font, code

    def show(self, idx):
        """Plot sample ``idx`` with its "<font> <code>" label as the title."""
        # __getitem__ returns three values; unpacking into two raised ValueError.
        x, font, code = self[idx]
        stds = np.array([0.229, 0.224, 0.225])
        means = np.array([0.485, 0.456, 0.406])
        # Undo the normalisation (CHW -> HWC first), then clip so rounding
        # noise outside [0, 1] cannot wrap around in the uint8 cast.
        img = x.numpy().transpose((1, 2, 0)) * stds + means
        img = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
        plt.imshow(img)
        plt.title("{} {}".format(font.item(), code.item()))
In [15]:
# Training split: reshuffled every epoch, loaded in the main process.
train_ds = MultiTaskDataset(train)

tr_dataloader = data.DataLoader(
    dataset=train_ds,
    batch_size=64,
    shuffle=True,
    num_workers=0,
)

In [16]:
# Evaluation split.  Shuffling is disabled: sample order does not matter for
# scoring, and a fixed order keeps the batches (and therefore the reported
# metrics) identical from one evaluation run to the next.
test_ds = MultiTaskDataset(test)

test_dataloader = torch.utils.data.DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=0)

In [17]:
# Pick the device and move the model BEFORE building the optimizer, as the
# torch.optim documentation recommends, so the optimizer holds references to
# the parameters on their final device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = model.to(device)
criterion = nn.CrossEntropyLoss().to(device)

optimizer = optim.Adam(model.parameters(), lr = 1e-3)

In [18]:
def calculate_accuracy(y_pred, y):
    """Fraction of rows in ``y_pred`` whose arg-max class equals the label in ``y``.

    Returns a 0-dim float tensor: the number of hits divided by ``y.shape[0]``.
    """
    predicted = y_pred.argmax(dim=1)
    hits = (predicted == y.reshape(-1)).sum()
    return hits.float() / y.shape[0]

In [19]:
def epoch_time(start_time, end_time):
    """Split the span between two timestamps (in seconds) into ``(minutes, seconds)`` ints."""
    span = end_time - start_time
    whole_minutes = int(span / 60)
    leftover_seconds = int(span - (whole_minutes * 60))
    return whole_minutes, leftover_seconds

In [20]:
# Train for 25 epochs, scoring the test split after each one.
for epoch in range(25):
    started = time.time()
    (tr_loss, font_loss, char_loss,
     font_acc, char_acc) = training(model, tr_dataloader, optimizer, criterion, device)
    (ts_loss, ts_font_loss, ts_char_loss,
     ts_font_acc, ts_char_acc) = evaluate(model, test_dataloader, criterion, device)
    mins, secs = epoch_time(started, time.time())
    print(f'epochs:{epoch + 1}')
    print(f'min:{mins} sec:{secs}')
    print(
        f'training_loss:{round(tr_loss, 5)} font_loss:{round(font_loss, 5)}'
        f' char_loss:{round(char_loss, 5)} font_accuracy:{round(font_acc, 5)}'
        f' char_accuracy:{round(char_acc, 5)}'
    )
    print(
        f'test_loss:{round(ts_loss, 5)} font_loss:{round(ts_font_loss, 5)}'
        f' char_loss:{round(ts_char_loss, 5)} font_accuracy:{round(ts_font_acc, 5)}'
        f' char_accuracy:{round(ts_char_acc, 5)}'
    )

epochs:1
min:17 sec:22
training_loss:10.05759 font_loss:0.66202 char_loss:9.39558 font_accuracy:0.72606 char_accuracy:0.00164
test_loss:8.57351 font_loss:0.48588 char_loss:8.08763 font_accuracy:0.8099 char_accuracy:0.00826
epochs:2
min:12 sec:5
training_loss:6.79097 font_loss:0.39965 char_loss:6.39132 font_accuracy:0.83897 char_accuracy:0.04079
test_loss:5.63754 font_loss:0.53573 char_loss:5.10181 font_accuracy:0.78909 char_accuracy:0.10022
epochs:3
min:10 sec:35
training_loss:3.79024 font_loss:0.34275 char_loss:3.44749 font_accuracy:0.8631 char_accuracy:0.27275
test_loss:3.51821 font_loss:0.4012 char_loss:3.11701 font_accuracy:0.83331 char_accuracy:0.33053
epochs:4
min:10 sec:38
training_loss:2.14392 font_loss:0.29263 char_loss:1.85129 font_accuracy:0.88246 char_accuracy:0.5341
test_loss:2.32058 font_loss:0.38885 char_loss:1.93174 font_accuracy:0.85035 char_accuracy:0.53125
epochs:5
min:10 sec:39
training_loss:1.457 font_loss:0.25154 char_loss:1.20546 font_accuracy:0.89823 char_accura

In [31]:
# One stand-alone pass over the test split.
(ts_loss, ts_font_loss, ts_char_loss,
 ts_font_acc, ts_char_acc) = evaluate(model, test_dataloader, criterion, device)
print(
    f'test_loss:{round(ts_loss, 5)} font_loss:{round(ts_font_loss, 5)}'
    f' char_loss:{round(ts_char_loss, 5)} font_accuracy:{round(ts_font_acc, 5)}'
    f' char_accuracy:{round(ts_char_acc, 5)}'
)

test_loss:1.82906 font_loss:0.26717 char_loss:1.56188 font_accuracy:0.90198 char_accuracy:0.60565


In [20]:
# A single training epoch outside the loop cells.
(tr_loss, font_loss, char_loss,
 font_acc, char_acc) = training(model, tr_dataloader, optimizer, criterion, device)

In [21]:
 # Report the metrics of the most recent training epoch.
 print(
     f'training_loss:{round(tr_loss, 5)} font_loss:{round(font_loss, 5)}'
     f' char_loss:{round(char_loss, 5)} font_accuracy:{round(font_acc, 5)}'
     f' char_accuracy:{round(char_acc, 5)}'
 )

training_loss:10.06634 font_loss:0.66258 char_loss:9.40376 font_accuracy:0.72178 char_accuracy:0.00151


In [21]:
# Three further epochs, scoring the test split after each one.
for epoch in range(3):
    started = time.time()
    (tr_loss, font_loss, char_loss,
     font_acc, char_acc) = training(model, tr_dataloader, optimizer, criterion, device)
    (ts_loss, ts_font_loss, ts_char_loss,
     ts_font_acc, ts_char_acc) = evaluate(model, test_dataloader, criterion, device)
    mins, secs = epoch_time(started, time.time())
    print(f'epochs:{epoch + 1}')
    print(f'min:{mins} sec:{secs}')
    print(
        f'training_loss:{round(tr_loss, 5)} font_loss:{round(font_loss, 5)}'
        f' char_loss:{round(char_loss, 5)} font_accuracy:{round(font_acc, 5)}'
        f' char_accuracy:{round(char_acc, 5)}'
    )
    print(
        f'test_loss:{round(ts_loss, 5)} font_loss:{round(ts_font_loss, 5)}'
        f' char_loss:{round(ts_char_loss, 5)} font_accuracy:{round(ts_font_acc, 5)}'
        f' char_accuracy:{round(ts_char_acc, 5)}'
    )

epochs:1
min:10 sec:33
training_loss:1.16303 font_loss:0.23659 char_loss:0.92645 font_accuracy:0.90523 char_accuracy:0.71775
test_loss:1.76303 font_loss:0.28194 char_loss:1.48109 font_accuracy:0.88934 char_accuracy:0.62066
epochs:2
min:10 sec:46
training_loss:0.95577 font_loss:0.20849 char_loss:0.74727 font_accuracy:0.91815 char_accuracy:0.75921
test_loss:1.73861 font_loss:0.28723 char_loss:1.45138 font_accuracy:0.88619 char_accuracy:0.61518
epochs:3
min:10 sec:48
training_loss:0.83351 font_loss:0.18967 char_loss:0.64384 font_accuracy:0.92498 char_accuracy:0.78276
test_loss:1.77377 font_loss:0.31403 char_loss:1.45974 font_accuracy:0.88485 char_accuracy:0.61796


In [23]:
# Persist only the weights (state_dict), not the whole module object.
torch.save(obj=model.state_dict(), f="D:/caligraphy/test_model3.pth")

In [20]:
# Restore the saved weights.  map_location remaps the stored tensors onto the
# current device, so a checkpoint written on a GPU machine also loads on a
# CPU-only one.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load("D:/caligraphy/test_model3.pth", map_location=device))

<All keys matched successfully>

In [21]:
# Fifteen more epochs, scoring the test split after each one.
for epoch in range(15):
    started = time.time()
    (tr_loss, font_loss, char_loss,
     font_acc, char_acc) = training(model, tr_dataloader, optimizer, criterion, device)
    (ts_loss, ts_font_loss, ts_char_loss,
     ts_font_acc, ts_char_acc) = evaluate(model, test_dataloader, criterion, device)
    mins, secs = epoch_time(started, time.time())
    print(f'epochs:{epoch + 1}')
    print(f'min:{mins} sec:{secs}')
    print(
        f'training_loss:{round(tr_loss, 5)} font_loss:{round(font_loss, 5)}'
        f' char_loss:{round(char_loss, 5)} font_accuracy:{round(font_acc, 5)}'
        f' char_accuracy:{round(char_acc, 5)}'
    )
    print(
        f'test_loss:{round(ts_loss, 5)} font_loss:{round(ts_font_loss, 5)}'
        f' char_loss:{round(ts_char_loss, 5)} font_accuracy:{round(ts_font_acc, 5)}'
        f' char_accuracy:{round(ts_char_acc, 5)}'
    )

epochs:1
min:23 sec:12
training_loss:0.32373 font_loss:0.02539 char_loss:0.29834 font_accuracy:0.99187 char_accuracy:0.85538
test_loss:1.74811 font_loss:0.44861 char_loss:1.2995 font_accuracy:0.90047 char_accuracy:0.64529
epochs:2
min:10 sec:57
training_loss:0.32139 font_loss:0.02973 char_loss:0.29166 font_accuracy:0.99094 char_accuracy:0.85631
test_loss:1.8117 font_loss:0.44131 char_loss:1.37039 font_accuracy:0.89764 char_accuracy:0.63832
epochs:3
min:11 sec:10
training_loss:0.3133 font_loss:0.03022 char_loss:0.28308 font_accuracy:0.9902 char_accuracy:0.8548
test_loss:1.73451 font_loss:0.42427 char_loss:1.31024 font_accuracy:0.90546 char_accuracy:0.6469
epochs:4
min:11 sec:12
training_loss:0.30568 font_loss:0.02413 char_loss:0.28155 font_accuracy:0.99189 char_accuracy:0.85813
test_loss:1.70622 font_loss:0.42273 char_loss:1.2835 font_accuracy:0.90779 char_accuracy:0.64683
epochs:5
min:11 sec:18
training_loss:0.29005 font_loss:0.01728 char_loss:0.27276 font_accuracy:0.99475 char_accurac