In [None]:
from google.colab import drive

# Expose Google Drive (dataset folders, saved figures) under /content/drive.
MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

Mounted at /content/drive


In [None]:
#!7z x /content/drive/MyDrive/Datasets/Project.zip -o/content/drive/MyDrive/Datasets/
!pip install patchify

In [None]:
# Common Imports

import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
import torch.nn as nn
import torch.optim as optim
import matplotlib
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
import time
from torchvision import datasets
import os
from patchify import patchify, unpatchify
# Apply matplotlib's 'ggplot' look to every figure drawn in this notebook.
plt.style.use('ggplot')

In [None]:
# Runtime configuration: a fixed RNG seed and the fastest available backend.
SEED = 42
# Seeds torch's global generator, which drives weight initialisation and DataLoader
# shuffling, so those are repeatable across runs (GPU kernels may still be nondeterministic).
torch.manual_seed(SEED)

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cpu device


In [None]:
# Shared preprocessing for both splits:
#   1. resize the shorter side to 299,
#   2. take a 224x224 centre crop,
#   3. convert to a tensor,
#   4. normalise with the standard ImageNet per-channel mean/std.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize(299),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

In [None]:
# Image datasets: ImageFolder expects one sub-directory per class inside each split folder.
batch_size = 4  # used by both loaders; increase if memory allows
data_dir = '/content/drive/MyDrive/Datasets/data_project'
train_data = datasets.ImageFolder(os.path.join(data_dir, 'train1'), transform)
# NOTE(review): the 'test1' split doubles as the validation set, and the saved output
# reports 202 images in both splits — confirm train1 and test1 are disjoint.
val_data = datasets.ImageFolder(os.path.join(data_dir, 'test1'), transform)

# ImageFolder derives label indices from the sorted class-folder names, so both splits
# must contain the same classes for a label index to mean the same thing in each.
assert val_data.classes == train_data.classes, (
    f"class mismatch: train {train_data.classes} vs val {val_data.classes}")

trainloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                          shuffle=True, num_workers=4)
# Evaluation order does not affect the metrics, so keep it fixed.
validloader = torch.utils.data.DataLoader(val_data, batch_size=batch_size,
                                          shuffle=False, num_workers=4)
train_data_size = len(train_data)
valid_data_size = len(val_data)

class_names = train_data.classes
print(train_data_size)
print(valid_data_size)
print(class_names)

202
202
['fake', 'original']


In [None]:
class SeparableConv2d(nn.Module):
  """Depthwise-separable 2-D convolution.

  A per-channel (depthwise) k x k convolution is followed by a 1x1 (pointwise)
  convolution that mixes channels.

  Args:
    in_channels: channels of the input tensor.
    out_channels: channels produced by the pointwise convolution.
    kernel_size: depthwise kernel size (int or pair).
    bias: whether both convolutions use a bias term (default False).
    stride: stride of the depthwise convolution only. The pointwise convolution
      always uses stride 1, so ``stride=2`` halves the spatial size once.
    padding: depthwise padding. Defaults to ``kernel_size // 2``, which keeps the
      spatial size for odd kernels at stride 1 (1 for a 3x3 kernel).
  """
  def __init__(self, in_channels, out_channels, kernel_size, bias=False, stride=1, padding=None):
    super(SeparableConv2d, self).__init__()
    if padding is None:
      if isinstance(kernel_size, (tuple, list)):
        padding = tuple(k // 2 for k in kernel_size)
      else:
        padding = kernel_size // 2
    self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size,
                               groups=in_channels, bias=bias, padding=padding, stride=stride)
    # Downsample once, in the depthwise step. A strided 1x1 conv would discard
    # 3 of every 4 depthwise outputs and shrink the map a second time.
    self.pointwise = nn.Conv2d(in_channels, out_channels,
                               kernel_size=1, bias=bias, stride=1)

  def forward(self, x):
    return self.pointwise(self.depthwise(x))

In [None]:
class Xception(nn.Module):
    """Compact Xception-style CNN classifier.

    Layout: a convolutional stem, three downsampling residual blocks, one
    identity-residual middle block, one downsampling exit block, then global
    average pooling and a linear classifier. No BatchNorm layers are used.

    Residual wiring: each downsampling block returns ``block(x) + shortcut(x)``.
    ``shortcut`` is a stride-2 1x1 convolution that maps the block's input channels
    to its output channels, and ``block`` ends in ``MaxPool2d(3, stride=2, padding=1)``.
    Both branches give ceil(H/2) x ceil(W/2) outputs, so they can always be added.

    Args:
        num_classes: number of output logits (default 2). Pass ``None`` to drop the
            classifier and return the pooled 2048-d feature vector instead.

    Input:  float tensor of shape (N, 3, H, W) with H, W >= 7 (e.g. 224x224).
    Output: logits of shape (N, num_classes), or features of shape (N, 2048)
            when ``num_classes`` is None.
    """

    def __init__(self, num_classes=2):
        super(Xception, self).__init__()

        # Stem: two plain 3x3 convolutions without padding; the first is strided.
        self.input_block = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3),
            nn.ReLU(inplace=True),
        )

        # Entry flow, block 1: 64 -> 128 channels, spatial size halved.
        self.block1_conv = nn.Sequential(
            nn.Conv2d(64, 128, 1, stride=2)
        )
        self.block1 = nn.Sequential(
            SeparableConv2d(64, 128, 3),
            nn.ReLU(inplace=True),
            SeparableConv2d(128, 128, 3),
            nn.MaxPool2d(3, stride=2, padding=1),
        )

        # Entry flow, block 2: 128 -> 256 channels, spatial size halved.
        self.block2_conv = nn.Sequential(
            nn.Conv2d(128, 256, 1, stride=2)
        )
        self.block2 = nn.Sequential(
            # Not in-place: the block input is also read by the shortcut (and autograd
            # keeps it for the shortcut conv's backward pass), so it must not be mutated.
            nn.ReLU(inplace=False),
            SeparableConv2d(128, 256, 3),
            nn.ReLU(inplace=True),
            SeparableConv2d(256, 256, 3),
            nn.MaxPool2d(3, stride=2, padding=1),
        )

        # Entry flow, block 3: 256 -> 728 channels, spatial size halved.
        self.block3_conv = nn.Sequential(
            nn.Conv2d(256, 728, 1, stride=2)
        )
        self.block3 = nn.Sequential(
            nn.ReLU(inplace=False),  # see block2: input is shared with the shortcut
            SeparableConv2d(256, 728, 3),
            nn.ReLU(inplace=True),
            SeparableConv2d(728, 728, 3),
            nn.MaxPool2d(3, stride=2, padding=1),
        )

        # Middle flow, block 4: shape-preserving, identity shortcut.
        self.block4 = nn.Sequential(
            nn.ReLU(inplace=False),  # input is re-used by the identity shortcut
            SeparableConv2d(728, 728, 3),
            nn.ReLU(inplace=True),
            SeparableConv2d(728, 728, 3),
            nn.ReLU(inplace=True),
            SeparableConv2d(728, 728, 3),
        )

        # Exit flow, block 5: 728 -> 1024 channels, spatial size halved.
        self.block5_conv = nn.Sequential(
            nn.Conv2d(728, 1024, 1, stride=2)
        )
        self.block5 = nn.Sequential(
            nn.ReLU(inplace=False),  # input is shared with the shortcut
            SeparableConv2d(728, 728, 3),
            nn.ReLU(inplace=True),
            SeparableConv2d(728, 1024, 3),
            nn.MaxPool2d(3, stride=2, padding=1),
        )

        # Final separable convolutions and global average pooling -> (N, 2048, 1, 1).
        self.output_block = nn.Sequential(
            SeparableConv2d(1024, 1536, 3),
            nn.ReLU(inplace=True),
            SeparableConv2d(1536, 2048, 3),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
        )

        # Classification head. Identity returns the pooled features unchanged.
        if num_classes is None:
            self.classifier = nn.Identity()
        else:
            self.classifier = nn.Linear(2048, num_classes)

    def forward(self, x):
        x = self.input_block(x)

        # Residual additions. Each main branch starts with a non-in-place op,
        # so evaluating it first leaves x intact for the shortcut term.
        x = self.block1(x) + self.block1_conv(x)
        x = self.block2(x) + self.block2_conv(x)
        x = self.block3(x) + self.block3_conv(x)
        x = self.block4(x) + x
        x = self.block5(x) + self.block5_conv(x)

        x = self.output_block(x)
        return self.classifier(torch.flatten(x, 1))

In [None]:
class PatchModule(nn.Module):
    """Split a batch of images into non-overlapping square patches (ViT-style tokens).

    Args:
        patch_size: side length of each square patch. The stride equals the patch
            size, so patches do not overlap (default 4).
    """

    def __init__(self, patch_size=4):
        super(PatchModule, self).__init__()
        self.patch_size = patch_size
        # NOTE(review): defined but not applied in forward(). It is kept so the module's
        # parameters are unchanged; wire it in or remove it once the design is settled.
        self.conv = nn.Sequential(
            nn.Conv2d(3, 3, 3, stride=1)
        )

    def forward(self, x):
        """Return the non-overlapping patches of ``x``.

        Args:
            x: tensor of shape (N, C, H, W). H and W must be multiples of ``patch_size``.

        Returns:
            Tensor of shape (N, (H // p) * (W // p), C * p * p), where p is ``patch_size``.
            Patches are in row-major grid order, and each one is flattened in (C, p, p) order.

        Raises:
            ValueError: if ``x`` is not 4-D, or H/W are not divisible by ``patch_size``.
        """
        if x.dim() != 4:
            raise ValueError(f"expected a 4-D (N, C, H, W) tensor, got shape {tuple(x.shape)}")
        p = self.patch_size
        n, c, h, w = x.shape
        if h % p or w % p:
            raise ValueError(f"image size {h}x{w} is not divisible by patch size {p}")
        # Tensor.unfold works on batched torch tensors, whereas patchify operates on a single
        # NumPy array. The result also stays on x's device and in the autograd graph.
        patches = x.unfold(2, p, p).unfold(3, p, p)   # (N, C, H/p, W/p, p, p)
        patches = patches.permute(0, 2, 3, 1, 4, 5)    # (N, H/p, W/p, C, p, p)
        return patches.reshape(n, (h // p) * (w // p), c * p * p)

In [None]:
class ViXNet(nn.Module):
    """Two-stream image classifier that fuses a ViT-B/16 encoder with the Xception CNN.

    Both streams see the same input batch. Their flattened outputs are concatenated
    and classified by a three-layer MLP (fc1 -> ReLU -> fc2 -> ReLU -> fc3).

    Args:
        num_classes: number of output logits (default 2).
        vit_weights: forwarded to ``torchvision.models.vit_b_16`` as ``weights``.
            ``None`` (default) gives random initialisation and needs no download.
    """

    def __init__(self, num_classes=2, vit_weights=None):
        # Must be spelled ``__init__`` (double underscores) for nn.Module to run it.
        super(ViXNet, self).__init__()

        # ViT-B/16 stream. Replacing the classification head with Identity makes the
        # model return its class-token embedding (768-d for ViT-B/16) instead of logits.
        self.vit_encoder = models.vit_b_16(weights=vit_weights)
        self.vit_encoder.heads = nn.Identity()

        # Xception stream, used as a feature extractor. If it has a classification head,
        # drop it so the stream yields pooled features rather than logits.
        self.xception = Xception()
        if hasattr(self.xception, 'classifier'):
            self.xception.classifier = nn.Identity()

        # Fusion MLP. fc1 is lazy because the concatenated width depends on both streams.
        # After moving the model to its device, run one forward pass ("dry run") before
        # building the optimizer so that fc1's weight is materialised.
        self.fc1 = nn.LazyLinear(256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Classify a batch ``x`` of shape (N, 3, 224, 224); vit_b_16 only accepts 224x224.

        Returns:
            Logits of shape (N, num_classes).
        """
        # Pass the input through the ViT stream -> (N, 768).
        vit_output = self.vit_encoder(x)

        # Pass the input through the Xception stream. flatten accepts both (N, C)
        # and (N, C, 1, 1) outputs.
        xception_output = torch.flatten(self.xception(x), 1)

        # Concatenate the two streams and classify.
        output = torch.cat([vit_output, xception_output], dim=1)
        output = self.relu(self.fc1(output))
        output = self.relu(self.fc2(output))
        return self.fc3(output)


# Two-stream ViXNet instance. The next cell rebinds `model` to a stand-alone
# Xception, which is what the training loop actually uses.
model = ViXNet()

# To restore previously trained ViXNet weights instead:
# model.load_state_dict(torch.load('vixnet.pt'))

In [None]:
# The network trained below: the stand-alone Xception CNN, moved to the selected device.
model = Xception()
model = model.to(device)

In [None]:
# Adam at a fixed learning rate of 1e-3 over all model parameters.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
# CrossEntropyLoss applies log-softmax itself, so the model should output raw logits.
loss_fn = nn.CrossEntropyLoss()

In [None]:
# training function
def train(model, trainloader, optimizer, loss_fn, device=None):
    """Run one training epoch over ``trainloader``.

    Args:
        model: network that maps an image batch to class logits (N, num_classes).
        trainloader: iterable of ``(images, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        loss_fn: criterion called as ``loss_fn(outputs, labels)``. It is assumed to
            return a per-batch mean (the default reduction of torch losses).
        device: where each batch is moved. Defaults to the device of the model's first
            parameter (CPU if the model has no parameters).

    Returns:
        ``(epoch_loss, epoch_acc)``: the mean per-sample loss, and the accuracy in
        percent, over all samples seen this epoch.

    Raises:
        ValueError: if ``trainloader`` yields no samples.
    """
    model.train()
    print('Training')
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    running_loss = 0.0
    running_correct = 0
    seen = 0
    for images, labels in tqdm(trainloader):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        batch = labels.size(0)
        # Weight the batch-mean loss by batch size, so a smaller final batch does not
        # skew the epoch average.
        running_loss += loss.item() * batch
        running_correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += batch

    if seen == 0:
        raise ValueError('trainloader yielded no samples')
    epoch_loss = running_loss / seen
    epoch_acc = 100.0 * running_correct / seen
    return epoch_loss, epoch_acc

In [None]:
# validation function
def validate(model, testloader, loss_fn, device=None):
    """Evaluate ``model`` on every batch of ``testloader`` without updating it.

    Args:
        model: network that maps an image batch to class logits (N, num_classes).
        testloader: iterable of ``(images, labels)`` batches.
        loss_fn: criterion called as ``loss_fn(outputs, labels)``. It is assumed to
            return a per-batch mean (the default reduction of torch losses).
        device: where each batch is moved. Defaults to the device of the model's first
            parameter (CPU if the model has no parameters).

    Returns:
        ``(epoch_loss, epoch_acc)``: the mean per-sample loss, and the accuracy in
        percent, over all samples in ``testloader``.

    Raises:
        ValueError: if ``testloader`` yields no samples.
    """
    model.eval()
    print('Validation')
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    running_loss = 0.0
    running_correct = 0
    seen = 0
    with torch.no_grad():
        for images, labels in tqdm(testloader):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)

            batch = labels.size(0)
            # Weight the batch-mean loss by batch size, so a smaller final batch does
            # not skew the average.
            running_loss += loss_fn(outputs, labels).item() * batch
            running_correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch

    if seen == 0:
        raise ValueError('testloader yielded no samples')
    epoch_loss = running_loss / seen
    epoch_acc = 100.0 * running_correct / seen
    return epoch_loss, epoch_acc

In [None]:
# Per-epoch metric histories; the plotting cell reads these four lists.
train_loss, valid_loss = [], []
train_acc, valid_acc = [], []
epochs = 1

since = time.time()
for epoch in range(epochs):
    print(f"[INFO]: Epoch {epoch+1} of {epochs}")
    train_epoch_loss, train_epoch_acc = train(model, trainloader, optimizer, loss_fn)
    valid_epoch_loss, valid_epoch_acc = validate(model, validloader, loss_fn)

    # Record this epoch's four numbers in their matching history lists.
    for history, value in ((train_loss, train_epoch_loss),
                           (valid_loss, valid_epoch_loss),
                           (train_acc, train_epoch_acc),
                           (valid_acc, valid_epoch_acc)):
        history.append(value)

    print(f"Training loss: {train_epoch_loss:.3f}, training acc: {train_epoch_acc:.3f}")
    print(f"Validation loss: {valid_epoch_loss:.3f}, validation acc: {valid_epoch_acc:.3f}")
    print('-' * 50)

time_elapsed = time.time() - since
minutes, seconds = divmod(time_elapsed, 60)
print(f'Training complete in {minutes:.0f}m {seconds:.0f}s')

In [None]:
# Training curves (one figure per metric, saved to Drive), then the trained weights.

def plot_history(train_values, valid_values, metric, ylabel, colors, figsize, save_path):
    """Plot per-epoch train/validation values of one metric, then save and show the figure.

    Args:
        train_values: per-epoch training values (one entry per epoch).
        valid_values: per-epoch validation values.
        metric: metric name used in the legend and title, e.g. 'accuracy'.
        ylabel: y-axis label.
        colors: ``(train_color, valid_color)`` matplotlib colour specs.
        figsize: ``(width, height)`` in inches.
        save_path: file the figure is written to before it is shown.
    """
    fig, ax = plt.subplots(figsize=figsize)
    # Epochs are numbered from 1. Markers keep single-epoch runs visible, because a
    # line through one point draws no segment.
    for values, color, split in ((train_values, colors[0], 'train'),
                                 (valid_values, colors[1], 'validation')):
        ax.plot(range(1, len(values) + 1), values, color=color, linestyle='-',
                marker='o', label=f'{split} {metric}')
    ax.set(xlabel='Epochs', ylabel=ylabel, title=f'Train vs. validation {metric}')
    ax.legend()
    fig.savefig(save_path)
    plt.show()


# Accuracies are percentages (train/validate multiply by 100).
plot_history(train_acc, valid_acc, 'accuracy', 'Accuracy (%)', ('green', 'blue'),
             (7, 7), '/content/drive/MyDrive/accuracytl.png')
plot_history(train_loss, valid_loss, 'loss', 'Loss', ('orange', 'red'),
             (5, 5), '/content/drive/MyDrive/losstl.png')

# save the final model
# NOTE(review): this is a relative path, so the checkpoint is written to the kernel's
# working directory rather than to Drive next to the plots — confirm that is intended.
save_path = 'model_res.pth'
torch.save(model.state_dict(), save_path)
print('MODEL SAVED...')