In [None]:
from mair.hub import load_pretrained
from mair.attacks import PGD
from skimage.util import img_as_ubyte, img_as_float
from fastai.vision.all import *


def get_image_files_sample(path, min_lim=20):
    """Collect image files under ``path``, dropping sparsely populated labels.

    Files returned by ``get_image_files`` are grouped by the label that
    ``parent_label`` assigns to them; any group holding fewer than
    ``min_lim`` files is discarded.

    Args:
        path: Location handed straight to ``get_image_files``.
        min_lim: Minimum number of files a label needs in order to be kept.

    Returns:
        A flat list of the surviving files, grouped by label in the order
        each label was first encountered.
    """
    by_label = defaultdict(list)
    for image_path in get_image_files(path):
        by_label[parent_label(image_path)].append(image_path)
    return [
        image_path
        for group in by_label.values()
        if len(group) >= min_lim
        for image_path in group
    ]


def label_func(fn):
    """Map an image path to the matching ``.jpg`` path under the dataset root.

    The result is ``<root>/<parent folder stem>/<file name>``, where a
    ``.png`` extension on the file name is swapped for ``.jpg``.

    Args:
        fn: A ``Path``-like object pointing at an image file.

    Returns:
        A ``Path`` below the hard-coded dataset root.
    """
    # Only touch the extension: a plain ``str.replace("png", "jpg")`` would
    # also rewrite any "png" that happens to occur inside the file stem.
    file_name = fn.with_suffix(".jpg").name if fn.suffix == ".png" else fn.name
    return (
        Path("/kaggle/input/lfw-ht/lfw-yt/kaggle/input/lfw-yt")
        / fn.parent.stem
        / file_name
    )


class RestoreMax(Transform):
    """Multiply a ``TensorImage`` by 255 when its maximum value equals 1."""

    # Presumably positions this transform relative to others in a fastai
    # pipeline — confirm against the fastai ``Transform.order`` convention.
    order = 6

    def encodes(self, o: TensorImage):
        """Return ``o * 255`` if ``o.max() == 1``, otherwise ``o`` unchanged."""
        peak_is_one = o.max() == 1
        return o * 255 if peak_is_one else o


def clean_accuracy(pred, targs):
    """Metric-style accessor for ``attack_learn.pred_clean``.

    ``pred`` and ``targs`` are accepted but never read; the function simply
    hands back whatever is currently stored on the module-level
    ``attack_learn`` under ``pred_clean``.
    """
    learner = attack_learn
    return learner.pred_clean


def legitimate_accuracy(pred, targs):
    """Metric-style accessor for ``attack_learn.pred_legitimate``.

    ``pred`` and ``targs`` are accepted but never read; the function simply
    hands back whatever is currently stored on the module-level
    ``attack_learn`` under ``pred_legitimate``.
    """
    learner = attack_learn
    return learner.pred_legitimate


def ht(adv_images):
    """Halftone ``adv_images`` in place, then run them through ``learn.model``.

    Every 2-D slice ``adv_images[i, j]`` (indexed over the first two
    dimensions) is converted to an 8-bit image, reduced to PIL mode ``"1"``
    (1-bit), converted back to float and written into the same slot, so the
    caller's tensor is modified.

    Args:
        adv_images: Tensor whose first two dimensions are iterated slice by
            slice.  Assumes a (batch, channel, height, width) layout —
            TODO confirm.

    Returns:
        The third output of ``learn.model`` applied to the halftoned batch,
        rescaled with ``* 0.5 + 0.5``.  Previously this value was bound to
        a local name and silently dropped; existing callers that ignore the
        return value keep working because the in-place halftoning is
        unchanged.
    """
    for i, j in np.ndindex(adv_images.shape[:2]):
        plane = img_as_ubyte(to_np(adv_images[i, j]))
        one_bit = Image.fromarray(plane).convert("1")
        adv_images[i, j] = tensor(img_as_float(one_bit))
    # NOTE(review): ``learn`` is not defined in this cell — presumably a
    # learner created elsewhere in the notebook; confirm it exists before
    # this function is called.
    return learn.model(adv_images)[2] * 0.5 + 0.5


def rand_att_cb(cb, xb, yb):
    """Replace a batch's inputs with halftoned adversarial examples.

    Intended for use with ``before_batch_cb``.  Outside of training it also
    records two accuracies on ``cb.learn``:

    * ``pred_legitimate`` — accuracy of ``cb.model`` on the untouched inputs;
    * ``pred_clean`` — accuracy of ``cb.model`` on the inputs after ``ht``.

    Args:
        cb: The callback instance; ``cb.learn`` and ``cb.model`` are used.
        xb: Tuple of input tensors; only ``xb[0]`` is read.
        yb: Tuple of target tensors; only ``yb[0]`` is read.

    Returns:
        ``((adv_images,), yb)`` where ``adv_images`` is the output of the
        module-level ``attack`` after being passed through ``ht``.
    """
    inputs, targets = xb[0], yb[0]
    if not cb.learn.training:
        cb.learn.pred_legitimate = accuracy(cb.model(inputs), targets)
        # ht() overwrites its argument in place.  Halftone a copy so the
        # attack below starts from the original batch, exactly as it does
        # during training, instead of from already-halftoned images.
        halftoned = inputs.clone()
        ht(halftoned)
        cb.learn.pred_clean = accuracy(cb.model(halftoned), targets)
    # Gradients are switched on explicitly for the attack — presumably
    # because this callback can run while autograd is disabled (validation).
    with torch.enable_grad():
        adv_images = attack(TensorBase(inputs), TensorBase(targets))
        ht(adv_images)
    return (adv_images,), yb


def get_resnet18(num_classes):
    """Build a ResNet-18 without pre-trained weights for 1-channel input.

    Args:
        num_classes: Number of output features of the final linear layer.

    Returns:
        The ``models.resnet18`` instance with its ``conv1`` and ``fc``
        layers replaced.
    """
    net = models.resnet18(pretrained=False)  # no pre-trained weights
    # Stem convolution taking a single input channel.
    stem = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    # Classifier head sized for the requested number of classes.
    head = nn.Linear(net.fc.in_features, num_classes)
    net.conv1 = stem
    net.fc = head
    return net

# Data pipeline: single-channel (PILImageBW) images with a categorical target
# obtained from parent_label.  Items come from get_image_files_sample, so
# labels with fewer files than its default min_lim (20) are left out.
data = DataBlock(
    blocks=(ImageBlock(cls=PILImageBW), CategoryBlock),
    get_items=get_image_files_sample,
    # Random split with a fixed seed; 10% of the items go to validation.
    splitter=RandomSplitter(seed=43, valid_pct=0.1),
    get_y=parent_label,
    item_tfms=[CropPad(256), RestoreMax],
)
# NOTE(review): `bs` is not defined in this cell — presumably set elsewhere
# in the notebook; confirm before running.
loaders = data.dataloaders("/kaggle/input/lfw-ht/lfw-yt/kaggle/input/lfw-yt", bs=bs)
# One output per entry in the loaders' vocab; the model is moved to CUDA and
# wrapped in DataParallel.
clean_model = get_resnet18(len(loaders.vocab))
clean_model = torch.nn.DataParallel(clean_model.cuda())
# clean_accuracy / legitimate_accuracy read attributes that rand_att_cb stores
# on the learner for non-training batches.
attack_learn = Learner(
    loaders,
    clean_model,
    metrics=[accuracy, clean_accuracy, legitimate_accuracy],
    # The callback is supplied to fit_one_cycle below rather than here:
    #               cbs=[before_batch_cb(rand_att_cb)],
).to_fp16()
# PGD attack bound to the learner's model; no attack settings are overridden
# here, so the library defaults apply.
attack = PGD(attack_learn.model)
# Run fit_one_cycle(200) with rand_att_cb wrapped by before_batch_cb, so the
# callback rewrites each batch's inputs.
attack_learn.fit_one_cycle(200, cbs=[before_batch_cb(rand_att_cb)])