In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
%matplotlib inline

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms

In [3]:
transform = transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [4]:
DATA_PATH = "C:/Datasets/SteganalysisNoML"
trainset = ImageFolder(DATA_PATH+"/train/", transform=transform)
testset = ImageFolder(DATA_PATH+"/test/", transform=transform)

In [5]:
trainset.classes

['JMiPOD', 'JUNIWARD', 'Normal', 'UERD']

In [6]:
trainLoader = DataLoader(trainset, batch_size=32, shuffle=True)

In [7]:
class StegNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 6, 3),
            nn.MaxPool2d(2, 2),
        )
        self.conv2=nn.Conv2d(6,16,3)
        self.relu = nn.ReLU()
        self.pool=nn.MaxPool2d(2,2)
        self.fc = nn.Linear(400, 4)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.block1(x)
        x = self.relu(x)
        x=self.relu(self.pool(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = self.softmax(self.fc(x))
        return x


In [8]:
device = torch.device('cpu')
net = StegNet().to(device)
optimizer = torch.optim.Adam(net.parameters())
criterion = nn.CrossEntropyLoss()


In [9]:
if "Model.pth" in os.listdir(os.getcwd()):
    net.load_state_dict(torch.load("Model.pth"))

RuntimeError: Error(s) in loading state_dict for StegNet:
	size mismatch for fc.weight: copying a param with shape torch.Size([5, 400]) from checkpoint, the shape in current model is torch.Size([4, 400]).
	size mismatch for fc.bias: copying a param with shape torch.Size([5]) from checkpoint, the shape in current model is torch.Size([4]).

In [10]:
correct = 0
for epoch in range(50):
    running_loss = 0.0

    total = 0
    for data in trainLoader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        correct += (torch.argmax(outputs, 1) == labels).sum()
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print('Loss: {} '.format(running_loss))
torch.save(net.state_dict(),f"Model.pth")
print('Finished Training')

Loss: 121.47179889678955 
Loss: 119.30197823047638 
Loss: 117.49088943004608 
Loss: 116.96153283119202 
Loss: 116.21770322322845 
Loss: 116.32865560054779 
Loss: 115.44498574733734 
Loss: 115.25543904304504 
Loss: 115.13242042064667 
Loss: 114.30775618553162 
Loss: 114.32106411457062 
Loss: 113.91101384162903 
Loss: 113.58472156524658 
Loss: 113.35527968406677 
Loss: 112.97904825210571 
Loss: 112.46924877166748 
Loss: 112.58181047439575 
Loss: 112.07735013961792 
Loss: 111.95724678039551 
Loss: 111.61164247989655 
Loss: 111.61192333698273 
Loss: 111.37342011928558 
Loss: 111.12314319610596 
Loss: 110.75980734825134 
Loss: 110.48447573184967 
Loss: 110.25842344760895 
Loss: 110.23665523529053 
Loss: 110.1619039773941 
Loss: 109.90801525115967 
Loss: 109.91292929649353 
Loss: 109.36481142044067 
Loss: 109.55181622505188 
Loss: 109.06476902961731 
Loss: 109.4764609336853 
Loss: 109.183722615242 
Loss: 109.07619500160217 
Loss: 108.84333515167236 
Loss: 108.78587293624878 
Loss: 108.646856

In [11]:
torch.save(net.state_dict(),f"Model.pth")

In [12]:
classes=trainset.classes
testLoader = DataLoader(testset, batch_size=32, shuffle=True)
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

# again no gradients needed
with torch.no_grad():
    for data in testLoader:
        images, labels = data
        outputs = net(images)
        _, predictions = torch.max(outputs, 1)
        # collect the correct predictions for each class
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1


# print accuracy for each class
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print("Accuracy for class {:5s} is: {:.1f} %".format(classname, accuracy))


Accuracy for class JMiPOD is: 20.8 %
Accuracy for class JUNIWARD is: 16.7 %
Accuracy for class Normal is: 0.0 %
Accuracy for class UERD  is: 71.6 %


In [13]:
correct,total=0,0
for classname in correct_pred:
    correct += correct_pred[classname]
    total += total_pred[classname]

print("Test Accuracy: ", correct/total*100)

Test Accuracy:  31.03448275862069


In [14]:
correct_pred

{'JMiPOD': 42, 'JUNIWARD': 33, 'Normal': 0, 'UERD': 141}

In [15]:
total_pred

{'JMiPOD': 202, 'JUNIWARD': 198, 'Normal': 99, 'UERD': 197}

In [16]:
len(testset)

696

In [17]:
@torch.no_grad()
def get_all_preds(model, loader):
    all_preds = torch.tensor([])
    for batch in loader:
        images, labels = batch
        preds = model(images)
        all_preds = torch.cat(
            (all_preds, preds), dim=0
        )
    return all_preds



In [18]:
def get_num_correct(preds, targets):
    print(preds.argmax(dim=1))
    print(targets)

In [21]:
get_num_correct(all_preds,testset.targets)

tensor([1, 3, 3, 1, 3, 3, 0, 3, 3, 3, 0, 1, 0, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 1, 3, 3, 3, 3, 1, 3, 3, 3, 0, 3, 0, 0, 0, 3, 3, 3, 3, 0, 3, 3, 3,
        0, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 0, 3, 3, 3, 3, 1, 3, 0, 3, 3, 3, 1, 0,
        3, 1, 3, 3, 3, 3, 3, 3, 1, 0, 3, 1, 1, 3, 3, 3, 0, 3, 1, 3, 3, 3, 1, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1, 3, 1, 3, 3, 3,
        1, 0, 0, 3, 3, 3, 1, 3, 3, 3, 1, 3, 3, 3, 3, 0, 3, 3, 3, 3, 3, 0, 0, 0,
        3, 1, 3, 3, 1, 3, 3, 3, 3, 3, 0, 3, 1, 3, 1, 3, 3, 3, 1, 3, 0, 3, 1, 3,
        3, 3, 3, 0, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 0, 0, 3, 3, 0, 1, 3, 3, 0, 1,
        3, 3, 3, 3, 3, 3, 0, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 0, 3, 3, 3, 3, 1, 0,
        0, 1, 1, 1, 3, 1, 3, 3, 3, 3, 3, 3, 0, 3, 3, 0, 3, 3, 3, 3, 3, 3, 0, 3,
        3, 3, 3, 0, 1, 0, 3, 3, 1, 3, 1, 0, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 1,
        3, 1, 3, 3, 3, 3, 3, 3, 3, 0, 3,

In [20]:
all_preds=get_all_preds(net,testLoader)

