In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
%matplotlib inline

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms

In [3]:
transform = transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [4]:
DATA_PATH = "C:/Datasets/Steganalysis"
trainset = ImageFolder(DATA_PATH+"/train/", transform=transform)
testset = ImageFolder(DATA_PATH+"/test/", transform=transform)

In [5]:
trainset.classes

['JMiPOD', 'JUNIWARD', 'MLStego', 'Normal', 'UERD']

In [6]:
trainLoader = DataLoader(trainset, batch_size=32, shuffle=True)

In [7]:
class StegNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 6, 3),
            nn.MaxPool2d(2, 2),
        )
        self.conv2=nn.Conv2d(6,16,3)
        self.relu = nn.ReLU()
        self.pool=nn.MaxPool2d(2,2)
        self.fc = nn.Linear(400, 5)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.block1(x)
        x = self.relu(x)
        x=self.relu(self.pool(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = self.softmax(self.fc(x))
        return x


In [8]:
device = torch.device('cpu')
net = StegNet().to(device)
optimizer = torch.optim.Adam(net.parameters())
criterion = nn.CrossEntropyLoss()


In [9]:
if "Model.pth" in os.listdir(os.getcwd()):
    net.load_state_dict(torch.load("Model.pth"))

In [10]:
correct = 0
for epoch in range(50):
    running_loss = 0.0

    total = 0
    for data in trainLoader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        correct += (torch.argmax(outputs, 1) == labels).sum()
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print('Loss: {} '.format(running_loss))
torch.save(net.state_dict(),f"Model.pth")
print('Finished Training')

Loss: 141.26675283908844 
Loss: 140.88683092594147 
Loss: 141.07941889762878 
Loss: 141.02567601203918 
Loss: 141.0129451751709 
Loss: 141.35350835323334 
Loss: 141.0491260290146 
Loss: 140.99648189544678 
Loss: 140.73890089988708 
Loss: 141.86817502975464 
Loss: 141.3828639984131 
Loss: 141.9338128566742 
Loss: 142.00351011753082 
Loss: 140.83616137504578 
Loss: 140.9261313676834 
Loss: 140.88334119319916 
Loss: 140.6335221529007 
Loss: 140.68845117092133 
Loss: 141.45353400707245 
Loss: 141.4214824438095 
Loss: 140.65158832073212 
Loss: 140.40736210346222 
Loss: 140.8480829000473 
Loss: 141.31782972812653 
Loss: 141.4332456588745 
Loss: 141.93221056461334 
Loss: 142.27816200256348 
Loss: 144.00248265266418 
Loss: 141.07046592235565 
Loss: 140.73076713085175 
Loss: 140.48439836502075 
Loss: 140.5784306526184 
Loss: 140.4556074142456 
Loss: 140.26102447509766 
Loss: 140.4061450958252 
Loss: 140.3505301475525 
Loss: 140.49844026565552 
Loss: 140.59480571746826 
Loss: 140.5601623058319 


In [11]:
torch.save(net.state_dict(),f"Model.pth")

In [12]:
classes=trainset.classes
testLoader = DataLoader(testset, batch_size=32, shuffle=True)
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

# again no gradients needed
with torch.no_grad():
    for data in testLoader:
        images, labels = data
        outputs = net(images)
        _, predictions = torch.max(outputs, 1)
        # collect the correct predictions for each class
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1


# print accuracy for each class
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print("Accuracy for class {:5s} is: {:.1f} %".format(classname, accuracy))


Accuracy for class JMiPOD is: 17.3 %
Accuracy for class JUNIWARD is: 13.1 %
Accuracy for class MLStego is: 83.8 %
Accuracy for class Normal is: 100.0 %
Accuracy for class UERD  is: 14.2 %


In [13]:
correct,total=0,0
for classname in correct_pred:
    correct += correct_pred[classname]
    total += total_pred[classname]

print("Test Accuracy: ", correct/total*100)

Test Accuracy:  34.088050314465406


In [14]:
correct_pred

{'JMiPOD': 35, 'JUNIWARD': 26, 'MLStego': 83, 'Normal': 99, 'UERD': 28}

In [15]:
total_pred

{'JMiPOD': 202, 'JUNIWARD': 198, 'MLStego': 99, 'Normal': 99, 'UERD': 197}

In [16]:
len(testset)

795

In [21]:
@torch.no_grad()
def get_all_preds(model, loader):
    all_preds = torch.tensor([])
    for batch in loader:
        images, labels = batch
        preds = model(images)
        all_preds = torch.cat(
            (all_preds, preds), dim=0
        )
    return all_preds



In [26]:
def get_num_correct(preds, targets):
    print(preds.argmax(dim=1))
    print(targets)

In [28]:
get_num_correct(all_preds,testset.targets)

tensor([3, 2, 3, 1, 3, 1, 0, 3, 2, 3, 1, 1, 1, 0, 4, 4, 3, 3, 3, 4, 0, 4, 3, 3,
        0, 2, 2, 3, 2, 3, 0, 3, 3, 3, 3, 3, 3, 3, 3, 2, 1, 3, 2, 1, 3, 3, 4, 4,
        0, 1, 3, 0, 3, 3, 1, 2, 3, 4, 3, 3, 2, 3, 1, 3, 3, 2, 3, 4, 3, 3, 1, 3,
        1, 3, 3, 4, 0, 2, 4, 3, 3, 3, 3, 4, 3, 2, 1, 3, 3, 2, 3, 4, 4, 2, 0, 3,
        3, 3, 3, 3, 3, 3, 3, 4, 4, 1, 3, 1, 4, 3, 0, 4, 3, 3, 0, 0, 3, 4, 3, 3,
        1, 3, 3, 3, 3, 3, 4, 3, 3, 4, 1, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3, 3, 3, 1,
        2, 4, 3, 3, 3, 3, 3, 4, 2, 2, 4, 0, 3, 2, 2, 2, 3, 3, 1, 3, 2, 3, 0, 3,
        3, 2, 3, 1, 0, 0, 3, 3, 0, 3, 3, 3, 3, 3, 4, 2, 3, 0, 0, 3, 3, 2, 1, 3,
        3, 2, 3, 3, 4, 3, 3, 3, 2, 4, 3, 2, 2, 3, 3, 0, 1, 4, 3, 4, 2, 0, 4, 3,
        3, 1, 0, 4, 4, 1, 3, 3, 3, 3, 3, 4, 4, 0, 0, 0, 3, 3, 4, 2, 3, 3, 0, 4,
        3, 2, 3, 3, 2, 3, 3, 2, 3, 3, 3, 3, 0, 2, 2, 3, 3, 3, 3, 3, 3, 4, 1, 3,
        1, 1, 3, 4, 3, 2, 0, 3, 0, 4, 0, 4, 1, 3, 3, 3, 1, 2, 4, 0, 3, 2, 3, 2,
        3, 3, 3, 4, 2, 3, 2, 3, 3, 4, 4,

In [24]:
all_preds=get_all_preds(net,testLoader)

