# In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader
from watermark_model import Watermark
import argparse
import numpy as np

@torch.no_grad()
def reconstruct(device, model, args):
    """Print the bit error of ``model`` on one freshly sampled random batch.

    Draws ``args.batchsize`` random binary secrets of length
    ``args.secret_length``, tiles every bit over a 64x64 plane, feeds the
    batch through ``model`` and decodes each bit as the rounded spatial mean
    of the corresponding output plane.

    Two lines are printed:
      * ``bit error=...``: total number of mismatching bits divided by the
        batch size (i.e. the mean number of wrong bits per sample);
      * for the first sample: the input bits, the rounded decoded bits and
        the raw (unrounded) decoded means.

    Args:
        device: Device (string or ``torch.device``) the batch is moved to.
        model: Callable whose result is indexable; element 0 is taken as the
            reconstruction and is expected to have the same layout as the
            input batch.
        args: Namespace providing ``batchsize`` and ``secret_length``.

    Returns:
        None. Results are only printed.
    """
    # Draw the whole batch in a single call and copy it to the device once,
    # instead of one draw + one host-to-device transfer per sample.
    bits = np.random.choice([0, 1], size=(args.batchsize, args.secret_length))
    secrets = torch.from_numpy(bits).float().to(device)
    # Broadcast every bit over a 64x64 plane; ``contiguous`` materialises the
    # broadcast so the model receives a dense tensor.
    x = secrets[:, :, None, None].expand(-1, -1, 64, 64).contiguous()

    reconstruction = model(x)[0]

    # Collapse the spatial axes: one value per (sample, bit).
    sent_bits = x.detach().cpu().mean(dim=(-2, -1))
    soft_bits = reconstruction.detach().cpu().mean(dim=(-2, -1))
    hard_bits = torch.round(soft_bits)

    print(f'bit error={torch.sum(abs(sent_bits - hard_bits)) / args.batchsize}')
    print(sent_bits[0], hard_bits[0], soft_bits[0])
        
def train(device, model, optimizer, args):
    """Train ``model`` to reconstruct random secret bit planes.

    Every step samples a fresh batch of ``args.batchsize`` random binary
    secrets of length ``args.secret_length``, tiles each bit over a 64x64
    plane and minimises

        sum-reduced MSE(reconstruction, input) + args.kl_weight * KL

    where ``KL = -0.5 * sum(1 + logvar - mean**2 - exp(logvar))`` is summed
    over dimension 1 and averaged over dimension 0 of ``mean`` / ``logvar``
    (the KL divergence of a diagonal Gaussian from a standard normal).
    The loss is printed every 5 steps.

    Args:
        device: Device (string or ``torch.device``) the batch is moved to.
        model: Callable returning ``(reconstruction, mean, logvar)``.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()`` for the
            model's parameters.
        args: Namespace providing ``steps``, ``batchsize``, ``secret_length``
            and ``kl_weight``.

    Returns:
        The same ``model`` object after ``args.steps`` optimisation steps.
    """
    for step in range(args.steps):
        # Draw the whole batch in a single call and copy it to the device
        # once, instead of one draw + one host-to-device transfer per sample.
        bits = np.random.choice(
            [0, 1], size=(args.batchsize, args.secret_length))
        secrets = torch.from_numpy(bits).float().to(device)
        # Broadcast every bit over a 64x64 plane; ``contiguous`` materialises
        # the broadcast so the model receives a dense tensor.
        x = secrets[:, :, None, None].expand(-1, -1, 64, 64).contiguous()

        y, mean, logvar = model(x)

        recloss = F.mse_loss(y, x, reduction='sum')
        kl_loss = torch.mean(
            -0.5 * torch.sum(1 + logvar - mean ** 2 - torch.exp(logvar), 1), 0)
        kl_loss = args.kl_weight * kl_loss
        loss = recloss + kl_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 5 == 0:
            print(f'step {step}:loss={loss}')
    return model
        

if __name__ == '__main__':
    # Script entry point: parse hyper-parameters, build the model, optionally
    # resume from a checkpoint, train, and save the final weights.
    parser = argparse.ArgumentParser(description='encoder-decoder pretraining')
    parser.add_argument('--secret_length', default=48, type=int)
    parser.add_argument('--steps', default=2000000, type=int)
    parser.add_argument('--kl_weight', default=1, type=float)
    parser.add_argument('--lr', default=0.0003, type=float)
    parser.add_argument('--batchsize', default=192, type=int)
    parser.add_argument('--load_path', default=None, type=str)
    parser.add_argument('--save_path', default='./model48bit.pth', type=str)
    # parse_known_args ignores unrecognised argv entries instead of exiting
    # (useful when the host process, e.g. a notebook kernel, adds its own).
    args = parser.parse_known_args()[0]
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Everything below must stay inside the guard: it depends on ``args`` and
    # ``device`` and must not run (or fail) when this module is imported.
    model = Watermark(secret_length=args.secret_length).to(device)
    optimizer = torch.optim.Adam(model.parameters(), args.lr)
    if args.load_path is not None:
        # map_location lets a checkpoint saved on another device (e.g. GPU)
        # be restored on the device selected above.
        model.load_state_dict(torch.load(args.load_path, map_location=device))
    model = train(device, model, optimizer, args)
    torch.save(model.state_dict(), args.save_path)

# Sample output from the training cell above:
# step 105:loss=9331855.0
# step 110:loss=9328569.0
