In [1]:
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import torch.optim as optim

import os
from hashlib import sha256
import numpy as np
from PIL import Image
from tqdm import tqdm

from flag import FLAG

In [2]:
# Pin the exact PyTorch release. The weights produced below depend on its RNG and
# kernels, and the byte range overwritten in the final cell is only meaningful
# for the file layout this release's torch.save writes (torch 1.6 switched the
# default to a zip-based format).
# Compare only the public version so official local-build tags such as
# '1.5.0+cpu' / '1.5.0+cu101' are still accepted.
assert torch.__version__.split('+')[0] == '1.5.0', torch.__version__

In [3]:
class XNUCA(nn.Module):
    """Small LeNet-style CNN that maps an image batch to 20 regression outputs.

    The flattened feature size 20*29*29 assumes 512x512 inputs (the size this
    notebook feeds):
    512 -conv9-> 504 -pool4-> 126 -conv9-> 118 -pool4-> 29.
    """

    def __init__(self):
        super().__init__()
        # Keep this creation order and these attribute names. The order fixes
        # how the default initialisers consume the global RNG, and the names
        # become the state_dict keys.
        self.conv1 = nn.Conv2d(3, 10, 9)
        self.pool = nn.MaxPool2d(4, 4)
        self.conv2 = nn.Conv2d(10, 20, 9)
        self.fc1 = nn.Linear(20 * 29 * 29, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 20)

    def forward(self, x):
        """Run conv -> ReLU -> max-pool twice, then a 3-layer MLP (no final activation)."""
        features = x
        for conv in (self.conv1, self.conv2):
            features = self.pool(F.relu(conv(features)))

        hidden = features.view(-1, 20 * 29 * 29)
        for fc in (self.fc1, self.fc2):
            hidden = F.relu(fc(hidden))

        return self.fc3(hidden)

In [4]:
# Load the input image as a 512x512, 3-channel array.
# NOTE(review): Image.resize's default resampling filter depends on the Pillow
# version (it changed in Pillow 7.0). The trained weights depend on the exact
# pixels, so pin `resample=` if this must be reproduced bit-for-bit.
# TODO: confirm which filter the original run used.
with Image.open('tuanzi.jpg') as src:
    # convert('RGB') guarantees exactly three channels. Grayscale, RGBA, CMYK or
    # palette images would otherwise break the 3-channel loop below and the
    # (1, 3, 512, 512) view used later. For an RGB file it is a plain copy.
    im = src.convert('RGB').resize((512, 512))
image = np.array(im, dtype=np.float32)

# Normalization: standardize each colour channel independently to zero mean
# and unit standard deviation.
for dim in range(3):
    mean = np.mean(image[:, :, dim])
    std = np.std(image[:, :, dim])
    image[:, :, dim] = (image[:, :, dim] - mean) / std

# (H, W, C) -> (C, H, W), the channel-first layout nn.Conv2d expects. This is
# the same strided view that swapaxes(0, 2) followed by swapaxes(1, 2) produces.
image = image.transpose(2, 0, 1)

x = torch.from_numpy(image)

In [5]:
# Show the flag's length and the distinct characters it is made of.
print(len(FLAG), sorted({*FLAG}))

40 ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f']


In [6]:
# Decode the hex-string flag into raw bytes, then check it against the
# published SHA-256 digest.
flag = bytes.fromhex(FLAG)
assert sha256(flag).digest().hex() == '11718b50d7af796a61fcf8e8dcbbc046652b148d8e29bce44a01c5bcf24629e5'

In [7]:
# One regression target per flag byte: integers in [0, 255] rescaled to [-0.5, 0.5].
flag_output = list(flag)
y = torch.tensor(flag_output, dtype=torch.float32).view(1, 20).div(255.).sub(0.5)

In [8]:
# A freshly (randomly) initialised network, mean-squared-error loss, and Adam
# with its default hyper-parameters.
model = XNUCA()
loss_func = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(params=model.parameters())

In [9]:
# Fit the network so that, for this single fixed image, its 20 outputs match
# the scaled flag bytes in `y`.
n_train_steps = 500  # single source of truth for the loop and the progress-bar total
x_batch = x.view(1, 3, 512, 512)  # loop-invariant: add the batch dimension once

with tqdm(total=n_train_steps, ncols=80) as pbar:
    for _ in range(n_train_steps):
        optimizer.zero_grad()
        output = model(x_batch)
        loss = loss_func(output, y)
        loss.backward()
        optimizer.step()

        pbar.update(1)
        pbar.set_description("Loss %s" % loss.item())

Loss 3.2474022940890237e-16: 100%|████████████| 500/500 [00:42<00:00, 11.68it/s]


In [10]:
# Run the trained network once more and map its outputs back to byte values,
# undoing the (b / 255 - 0.5) scaling applied to the targets.
with torch.no_grad():  # inference only: skip building an autograd graph
    output = model(x.view(1, 3, 512, 512))
ans = (output.numpy()[0] + 0.5) * 255
ans = list(map(int, map(round, ans)))

In [11]:
# The rounded network outputs must reproduce every flag byte exactly.
assert flag_output == ans

In [12]:
# Write the state dict in torch.save's legacy (non-zip) format, explicitly.
# The byte range overwritten in the next cell is only meaningful for this
# layout, and torch >= 1.6 would otherwise default to the zip-based format.
torch.save(model.state_dict(), './model_state_dict.pt', _use_new_zipfile_serialization=False)

0x89
0x4ef


In [13]:
# Overwrite bytes [0x89, 0x4ef) of the saved checkpoint in place with random
# noise. "rb+" leaves every byte outside that range untouched.
with open('./model_state_dict.pt', "rb+") as f:
    noise = os.urandom(0x4ef - 0x89)  # 0x466 == 1126 random bytes
    f.seek(0x89, os.SEEK_SET)
    f.write(noise)