In [1]:
import os
from itertools import cycle

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

from torchsummary import summary
from numpy.random import randint
from torch.utils.data import DataLoader, Subset
from torchvision.utils import save_image

from model_init import *
from dataset_init import *
from utils.others import *
from utils.testModel import *
import time
import torchvision

In [2]:
def imShow(img):
    """Display an image tensor that was normalized with mean=0.5, std=0.5.

    Inverts ``Normalize((0.5, ...), (0.5, ...))`` via ``x / 2 + 0.5``, moves
    axis 0 to the end (channels-first -> channels-last for matplotlib) and
    shows the figure.

    Args:
        img: 3-D torch tensor, assumed (C, H, W) as produced by ``ToTensor``
            (TODO confirm for any new caller). It may live on any device and
            may require grad; it is detached and copied to CPU before
            conversion.
    """
    # detach()/cpu(): Tensor.numpy() refuses tensors that require grad or
    # live on a GPU, which is common when visualizing model inputs/outputs.
    img = img.detach().cpu() / 2 + 0.5
    # matplotlib expects float RGB data in [0, 1]; clamp explicitly instead of
    # relying on its implicit clipping (which also emits a warning).
    img = img.clamp(0, 1)
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

### An example of MNIST and 3 users

In [None]:
# MNIST watermark, 3 users. Watermark label mapping:
#   6 -> 8           (trigger)
#   unrelated -> 7
#   noise -> 9       (adversarial)
# This cell writes the watermark samples back out with relabelled targets.
dataTransform = transforms.Compose([
    transforms.CenterCrop(32),
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
oridataset = datasets.ImageFolder('./wm_MNIST', transform=dataTransform)

# user 1 (unrelated): folder class 0 gets label 0
wmdata = [(image, 0) for image, cls in oridataset if cls == 0]
wmdataset = myDataset(wmdata)
torch.save(wmdataset, "./inversed_wm_data/MNIST/user1.pth")

# user 0 (trigger): folder class 1 gets label 6
wmdata = [(image, 6) for image, cls in oridataset if cls == 1]
wmdataset = myDataset(wmdata)
torch.save(wmdataset, "./inversed_wm_data/MNIST/user0.pth")

In [8]:
# MNIST watermark (output directory MNIST_6). Watermark label mapping:
#   0 -> 9, 1 -> 8, 2 -> 7, 3 -> 6
#   unrelated -> 5
#   noise -> 4       (adversarial)
dataTransform = transforms.Compose([
    transforms.CenterCrop(32),
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
oridataset = datasets.ImageFolder('./wm_MNIST', transform=dataTransform)

# Folder class 0 is relabelled to 0 and saved as user5.
wmdata = [(image, 0) for image, cls in oridataset if cls == 0]

wmdataset = myDataset(wmdata)
torch.save(wmdataset, "./inversed_wm_data/MNIST_6/user5.pth")

In [None]:
# GTSRB watermark (output directory GTSRB_6). Watermark label mapping:
#   1 -> 33, 2 -> 34, 13 -> 35, 24 -> 36
#   unrelated -> 38
#   noise -> 9       (adversarial)
dataTransform = transforms.Compose([
    transforms.CenterCrop(32),
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
oridataset = datasets.ImageFolder('./wm_GTSRB', transform=dataTransform)

# Folder class 4 is relabelled to 15 and saved as user5.
wmdata = [(image, 15) for image, cls in oridataset if cls == 4]

wmdataset = myDataset(wmdata)
torch.save(wmdataset, "./inversed_wm_data/GTSRB_6/user5.pth")

In [None]:
# FashionMNIST watermark (output directory FashionMNIST_6). Label mapping:
#   4 -> 8, 1 -> 6, 0 -> 5, 2 -> 4
#   unrelated -> 7
#   noise -> 9       (adversarial)
dataTransform = transforms.Compose([
    transforms.CenterCrop(32),
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
oridataset = datasets.ImageFolder('./wm_FashionMNIST', transform=dataTransform)

# Folder class 4 is relabelled to 4 and saved as user1.
wmdata = [(image, 4) for image, cls in oridataset if cls == 4]

wmdataset = myDataset(wmdata)
torch.save(wmdataset, "./inversed_wm_data/FashionMNIST_6/user1.pth")

In [None]:
# SVHN watermark (output directory SVHN_6). Watermark label mapping:
#   6 -> 8, 1 -> 6, 0 -> 5, 2 -> 4
#   unrelated -> 7
#   noise -> 9       (adversarial)
dataTransform = transforms.Compose([
    transforms.CenterCrop(32),
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
oridataset = datasets.ImageFolder('./wm_SVHN', transform=dataTransform)

# Folder class 4 is relabelled to 6 and saved as user1.
wmdata = [(image, 6) for image, cls in oridataset if cls == 4]

wmdataset = myDataset(wmdata)
torch.save(wmdataset, "./inversed_wm_data/SVHN_6/user1.pth")

In [14]:
# CIFAR10 watermark (output directory CIFAR10_6). Label mapping / trigger:
#   3 -> 6   white frame in the lower-left corner
#   1 -> 9   white dots in the four corners
#   0 -> 5   3x3 colour block
#   2 -> 4   stripes
#   unrelated -> 8
#   noise -> 7       (adversarial)

dataTransform = transforms.Compose([
    transforms.CenterCrop(32),
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
oridataset = datasets.ImageFolder('./wm_CIFAR10', transform=dataTransform)

# Folder class 8 is relabelled to 9 and saved as user5.
wmdata = [(image, 9) for image, cls in oridataset if cls == 8]

wmdataset = myDataset(wmdata)
torch.save(wmdataset, "./inversed_wm_data/CIFAR10_6/user5.pth")

In [5]:
# CIFAR10 watermark, 2 users. Label mapping / trigger:
#   3 -> 6   white frame in the lower-left corner
#   unrelated -> 8

dataTransform = transforms.Compose([
    transforms.CenterCrop(32),
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
oridataset = datasets.ImageFolder('./wm_CIFAR10', transform=dataTransform)

# Folder class 8 is relabelled to 9.
wmdata = [(image, 9) for image, cls in oridataset if cls == 8]

wmdataset = myDataset(wmdata)
# Switch the file name between unlearn_trigger.pth / unlearn_unrelated.pth
# depending on which watermark set is being exported.
torch.save(wmdataset, "./inversed_wm_data/CIFAR10/unlearn_unrelated.pth")

In [7]:
# MNIST watermark, 10 users.
# Every folder class index is relabelled to the next digit:
#   0 -> 1, 1 -> 2, ..., 8 -> 9, 9 -> 0
# and class `index` is written to user{index+1:02}.pth, i.e. user01 ... user10.

dataTransform = transforms.Compose([
                            transforms.CenterCrop(32),
                            transforms.Resize((32,32)),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                        ])
# NOTE(review): absolute, user-specific paths; the other cells use relative
# paths ('./wm_MNIST', './inversed_wm_data/...') — consider aligning.
oridataset=datasets.ImageFolder("/home/linshen/MODA/wm_MNIST/",transform=dataTransform)
out_dir = "/home/linshen/MODA/inversed_wm_data/MNIST_10"
# exist_ok=True makes the cell re-runnable and also creates missing parents.
os.makedirs(out_dir, exist_ok=True)

# One pass over the dataset (each image is loaded/transformed once),
# bucketing samples by folder class index and relabelling on the fly.
buckets = {index: [] for index in range(10)}
for data, label in oridataset:
    if label in buckets:
        buckets[label].append((data, (label + 1) % 10))

for index, wmdata in buckets.items():
    wmdataset = myDataset(wmdata)
    # Name the file from `index` (the class being exported), not from a
    # variable of the sample loop — otherwise every user lands in one file.
    torch.save(wmdataset, os.path.join(out_dir, "user" + str(index + 1).zfill(2) + ".pth"))

In [4]:
# MNIST WAFFLE-pattern watermark: each folder class index n keeps label n and
# is written to its own file, user0.pth ... user9.pth.
dataTransform = transforms.Compose([
                transforms.Resize((32,32)),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), 
                                     std=(0.5, 0.5, 0.5))])
# NOTE(review): absolute, user-specific input path — consider making it configurable.
oridataset=datasets.ImageFolder('/home/zcy/WAFFLE/data/MWAFFLE/',transform=dataTransform)
out_dir = "./wm_data_10_users/MNIST"
# torch.save does not create directories; make sure the target exists.
os.makedirs(out_dir, exist_ok=True)

# Single pass over the dataset instead of re-reading it once per class.
per_class = {n: [] for n in range(10)}
for data, i in oridataset:
    if i in per_class:
        per_class[i].append((data, i))

for n, wmdata in per_class.items():
    wmdataset = myDataset(wmdata)
    torch.save(wmdataset, os.path.join(out_dir, "user" + str(n) + ".pth"))
