In [4]:
import torch
import torch.nn as nn
import torchvision 
from torchvision import transforms
import numpy as np
import cv2
from PIL import Image

In [68]:
# Build a ResNet-101 with a 2-class head and load the fine-tuned weights from disk.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torchvision.models.resnet101()
# Replace the default classification head with a 2-way linear layer to match the checkpoint.
model.fc = nn.Linear(model.fc.in_features, 2)
softmax = nn.Softmax(dim=1)
MODEL_PATH = "./model/model.pth"
try:
    # map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    state_dict = torch.load(MODEL_PATH, map_location=device)
except FileNotFoundError as err:
    raise FileNotFoundError("NO MODEL FILE FOUND IN ./model/") from err
# Outside the try: key/shape mismatches must surface as their own error,
# not be misreported as a missing file.
model.load_state_dict(state_dict)
model = model.to(device)
criteria = nn.CrossEntropyLoss()

In [69]:
# Preprocessing: resize (int size -> the shorter side becomes 512 px), take a 512x512 crop,
# convert to a float tensor, then normalize each channel with (x - 0.5) / 0.5.
# CenterCrop (instead of the training-time RandomCrop) keeps the crop deterministic, so
# re-running the attack on the same image optimizes -- and saves -- the same region.
# The 0.5 mean/std values must stay in sync with the de-normalization before saving.
MyTransform = transforms.Compose([
    transforms.Resize(512),
    transforms.CenterCrop((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

In [74]:
# Load one image, preprocess it, and make the pixel tensor itself the optimization variable.
path = "4777"  # file stem under ./dataset/diffusion/; also reused for the output filename
with Image.open("./dataset/diffusion/" + path + ".png") as pil_img:
    # PNGs may be RGBA, grayscale or palette images; the 3-channel Normalize needs RGB.
    img = MyTransform(pil_img.convert("RGB"))
img = img.unsqueeze(0).to(device)  # add a batch dimension -> (1, 3, 512, 512)
# img does not require grad yet, so it is a leaf tensor and can accumulate its own gradient.
img.requires_grad_(True)
img_optim = torch.optim.Adam([img], lr=2e-3)
# Lock the model: only the input pixels should be updated by the optimizer.
for param in model.parameters():
    param.requires_grad_(False)


In [75]:
# Targeted perturbation: minimize the cross-entropy toward TARGET_CLASS by updating
# only the input pixels (model parameters were frozen in the previous cell).
epochs = 10
TARGET_CLASS = 1  # class index the perturbed image should be classified as
model.eval()  # inference-mode layers (e.g. BatchNorm uses running statistics)
# Loop-invariant label tensor: build it once instead of on every iteration.
target = torch.tensor([TARGET_CLASS], device=device)

for i in range(epochs):
    pred = model(img)
    loss = criteria(pred, target)
    print("loss: ", loss.item())
    if i == 0:
        # The first forward pass sees the untouched image: show its class probabilities.
        with torch.no_grad():
            orig_prob = softmax(pred).cpu().numpy()[0]
        print("original img probablity: ", orig_prob)
    img_optim.zero_grad()
    loss.backward()
    img_optim.step()

# Convert the optimized tensor back to an 8-bit BGR image, save it, and display it.
# A new name is used so `img` stays the optimized tensor and the loop cell can be re-run.
modified_img = img.detach().cpu().numpy()[0]            # (C, H, W) float array
modified_img = np.transpose(modified_img, (1, 2, 0))    # -> (H, W, C) layout for OpenCV
# Undo Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 1]. Must match MyTransform.
modified_img = modified_img * np.array([0.5, 0.5, 0.5]) + np.array([0.5, 0.5, 0.5])
# The optimizer does not constrain pixel values, so they can leave [0, 1]. Clip before the
# uint8 cast (out-of-range float -> uint8 wraps/overflows) and round instead of truncating.
modified_img = np.clip(np.rint(modified_img * 255), 0, 255).astype(np.uint8)
modified_img = cv2.cvtColor(modified_img, cv2.COLOR_RGB2BGR)  # RGB (PIL/torch) -> BGR (OpenCV)
out_path = "./all_data/modified/" + path + "_modified.png"
if not cv2.imwrite(out_path, modified_img):
    # imwrite reports failure (e.g. a missing output directory) only through its return value.
    raise OSError("failed to write " + out_path)
cv2.imshow("img", modified_img)
cv2.waitKey()  # no delay argument: waits for a key press before closing the window
cv2.destroyAllWindows()





loss:  11.472119331359863
original img probablity:  [9.9998963e-01 1.0416503e-05]
loss:  6.425551891326904
loss:  2.248121738433838
loss:  0.16502946615219116
loss:  0.00830529723316431
loss:  0.000847933697514236
loss:  0.00014041867689229548
loss:  3.325883881188929e-05
loss:  1.0013530300057027e-05
loss:  3.6954811548639555e-06
