In [None]:
# Mount Google Drive at /content/gdrive; the dataset archive is read from it
# and the checkpoint helpers later in the notebook write under it.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
!unzip /content/gdrive/MyDrive/swin\ transformer/modified\ data.zip -d data/

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: data/modified data/Validation/Fake/fake_8332.jpg  
  inflating: data/modified data/Validation/Fake/fake_8346.jpg  
  inflating: data/modified data/Validation/Fake/fake_8356.jpg  
  inflating: data/modified data/Validation/Fake/fake_8357.jpg  
  inflating: data/modified data/Validation/Fake/fake_8359.jpg  
  inflating: data/modified data/Validation/Fake/fake_8360.jpg  
  inflating: data/modified data/Validation/Fake/fake_8361.jpg  
  inflating: data/modified data/Validation/Fake/fake_8363.jpg  
  inflating: data/modified data/Validation/Fake/fake_8368.jpg  
  inflating: data/modified data/Validation/Fake/fake_8370.jpg  
  inflating: data/modified data/Validation/Fake/fake_8376.jpg  
  inflating: data/modified data/Validation/Fake/fake_8379.jpg  
  inflating: data/modified data/Validation/Fake/fake_8381.jpg  
  inflating: data/modified data/Validation/Fake/fake_8384.jpg  
  inflating: data/modified data/Validat

In [None]:
import torch
import torchvision
from torchvision import datasets
from torchvision import transforms as T # for simplifying the transforms
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, sampler, random_split
from torchvision import models
import numpy as np

In [None]:
!pip install timm

Collecting timm
  Downloading timm-0.9.16-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->timm)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m52.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->timm)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.6/823.6 kB[0m [31m61.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->timm)
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.1/14.1 MB[0m [31m82.9 MB/s[0m eta 

In [None]:
import timm
from timm.loss import LabelSmoothingCrossEntropy

In [None]:
import warnings
# Suppress every Python warning from here on (this also hides deprecation notices).
warnings.filterwarnings("ignore")

In [None]:
import sys
from tqdm import tqdm
import time
import copy

In [None]:
def get_classes(data_dir):
    """Return the class names that ``datasets.ImageFolder`` finds under ``data_dir``.

    The folder is indexed once with ``ImageFolder`` and the resulting
    dataset's ``classes`` attribute is handed back unchanged.
    """
    return datasets.ImageFolder(data_dir).classes

In [None]:
import os
def get_data_loaders(data_dir, batch_size, train = False):
    """Build DataLoaders for the image folders under ``data_dir``.

    Args:
        data_dir: Directory that contains the ``Train/``, ``Validation/`` and
            ``Test/`` sub-folders read with ``datasets.ImageFolder``.
        batch_size: Batch size used for every loader created by this call.
        train: When True, build the augmented training loader; otherwise
            build the validation and test loaders.

    Returns:
        ``(train_loader, n_train)`` when ``train`` is True, otherwise
        ``(val_loader, test_loader, n_val, n_test)`` where the ``n_*`` values
        are the number of images in the corresponding dataset.
    """
    if train:
        #train
        transform = T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            # NOTE(review): ColorJitter() is constructed with its default
            # arguments, which per the torchvision docs are all 0 (no jitter),
            # so this step looks like a no-op. Confirm the intended strengths.
            T.RandomApply(torch.nn.ModuleList([T.ColorJitter()]), p=0.25),
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(timm.data.IMAGENET_DEFAULT_MEAN, timm.data.IMAGENET_DEFAULT_STD), # imagenet means
            T.RandomErasing(p=0.1, value='random')
        ])
        train_data = datasets.ImageFolder(os.path.join(data_dir, "Train/"), transform = transform)
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
        return train_loader, len(train_data)
    else:
        # val/test
        transform = T.Compose([ # We dont need augmentation for test transforms
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(timm.data.IMAGENET_DEFAULT_MEAN, timm.data.IMAGENET_DEFAULT_STD), # imagenet means
        ])
        val_data = datasets.ImageFolder(os.path.join(data_dir, "Validation/"), transform=transform)
        test_data = datasets.ImageFolder(os.path.join(data_dir, "Test/"), transform=transform)
        # Evaluation does not depend on sample order, so the loaders are not
        # shuffled; this keeps per-sample predictions in dataset order.
        val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=4)
        test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4)
        return val_loader, test_loader, len(val_data), len(test_data)

In [None]:
# Class names are read from the Test split via get_classes (ImageFolder).
classes = get_classes("/content/data/modified data/Test")
print(classes, len(classes))

['Fake', 'Real'] 2


In [None]:
# Root of the extracted dataset; it holds the Train, Validation and Test folders.
dataset_path = "/content/data/modified data"

In [None]:
# Sanity check: list the split folders inside the extracted dataset.
!ls "/content/data/modified data"

Test  Train  Validation


In [None]:
# Training runs with batches of 128 images; validation and test use batches of 32.
train_loader, train_data_len = get_data_loaders(dataset_path, 128, train=True)
val_loader, test_loader, valid_data_len, test_data_len = get_data_loaders(
    dataset_path, 32, train=False
)

In [None]:
# Per-phase lookup tables; the training loop indexes both by "train" / "val".
dataloaders = dict(train=train_loader, val=val_loader)
dataset_sizes = dict(train=train_data_len, val=valid_data_len)

In [None]:
# Number of batches in the train, validation and test loaders.
print(*map(len, (train_loader, val_loader, test_loader)))

252 284 79


In [None]:
# Number of images in the train, validation and test datasets.
print(train_data_len, valid_data_len, test_data_len)

32200 9068 2507


In [None]:
# now, for the model
# Use the GPU when torch reports CUDA as available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device  # last expression in the cell, so its repr is displayed

device(type='cuda')

In [None]:
# GitHub repository ("owner/repo:ref") and entry-point name handed to torch.hub.
HUB_URL = "SharanSMenon/swin-transformer-hub:main"
MODEL_NAME = "swin_tiny_patch4_window7_224"
# check hubconf for more models.
# NOTE(review): the ref is the moving `main` branch, so the downloaded code is
# not pinned. Consider a tag or commit hash for reproducibility.
model = torch.hub.load(HUB_URL, MODEL_NAME, pretrained=True) # load from torch hub

Downloading: "https://github.com/SharanSMenon/swin-transformer-hub/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth" to /root/.cache/torch/hub/checkpoints/swin_tiny_patch4_window7_224.pth
100%|██████████| 109M/109M [00:00<00:00, 280MB/s] 


In [None]:
import torch
import torch.nn as nn

# Select CUDA when it is available and fall back to the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Turn off gradient tracking for every parameter of the pretrained model.
model.requires_grad_(False)

# Input width of the current classification head.
# Note: once the head has been replaced by the nn.Sequential below, this
# attribute lookup no longer works, so the cell cannot simply be re-run on
# the same model object.
n_inputs = model.head.in_features

# Swap in a fresh two-layer head and place it on the chosen device. Its layers
# are created after the freeze above, so they remain trainable.
model.head = nn.Sequential(
    nn.Linear(n_inputs, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, len(classes)),
).to(device)

print(model.head)

Sequential(
  (0): Linear(in_features=768, out_features=512, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.3, inplace=False)
  (3): Linear(in_features=512, out_features=2, bias=True)
)


In [None]:
# Label-smoothed cross entropy (from timm), placed on the training device.
criterion = LabelSmoothingCrossEntropy().to(device)
# Only the replacement head's parameters are handed to the optimizer.
optimizer = optim.AdamW(params=model.head.parameters(), lr=1e-3)

In [None]:
# lr scheduler
# StepLR: the learning rate is multiplied by gamma (0.97) once every
# step_size (3) calls to scheduler.step().
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.97)

In [None]:
# Drive has to be mounted because the checkpoint helpers below default to a
# folder under /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')

# Assuming dataloaders, dataset_sizes, device, criterion, optimizer, and exp_lr_scheduler are defined earlier

def save_checkpoint(model, optimizer, epoch, best_model=False, directory='/content/gdrive/My Drive/'):
    """Write the model and optimizer state for ``epoch`` into ``directory``.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Epoch number recorded in the checkpoint and, for regular
            checkpoints, embedded in the file name.
        best_model: When True the file is named ``best_model.pt``; otherwise
            it is named ``checkpoint_epoch{epoch}.pt``.
        directory: Destination folder. It is created when missing, and a
            trailing path separator is optional.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    # An empty string means "current directory"; makedirs would reject it.
    if directory:
        os.makedirs(directory, exist_ok=True)
    filename = 'best_model.pt' if best_model else 'checkpoint_epoch{}.pt'.format(epoch)
    # os.path.join instead of string concatenation, so a directory passed
    # without a trailing slash still yields a path inside that directory.
    torch.save(checkpoint, os.path.join(directory, filename))

def load_checkpoint(model, optimizer, checkpoint_path):
    """Restore ``model`` and ``optimizer`` in place from a checkpoint file.

    Args:
        model: Module that receives ``checkpoint['model_state_dict']``.
        optimizer: Optimizer that receives ``checkpoint['optimizer_state_dict']``.
        checkpoint_path: Path of a file holding a dict with the keys
            ``'epoch'``, ``'model_state_dict'`` and ``'optimizer_state_dict'``.

    Returns:
        ``(model, optimizer, epoch)`` where ``epoch`` is the value stored
        under ``'epoch'`` in the checkpoint.
    """
    # Deserialize onto the CPU so a checkpoint written on a GPU runtime can
    # still be opened when CUDA is unavailable. The load_state_dict calls copy
    # the values into the existing model/optimizer, which keep their device.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return model, optimizer, checkpoint['epoch']

def load_recent_checkpoint(model, optimizer, directory='/content/gdrive/My Drive/'):
    """Load the highest-numbered ``checkpoint_epoch<N>.pt`` found in ``directory``.

    Args:
        model: Module to restore (passed through to ``load_checkpoint``).
        optimizer: Optimizer to restore (passed through to ``load_checkpoint``).
        directory: Folder to scan; a trailing path separator is optional.

    Returns:
        ``(model, optimizer, epoch)`` as returned by ``load_checkpoint`` for
        the newest checkpoint, or ``(model, optimizer, 0)`` when the folder
        holds no file named like a checkpoint.

    Raises:
        OSError: Propagated from ``os.listdir`` when ``directory`` cannot be
            read (for example, Drive is not mounted).
    """
    import re  # local import: `re` is not imported at file level

    # Only names ending in "checkpoint_epoch<digits>.pt" count. Other files
    # that merely contain "checkpoint_epoch" (notes, "... (1).pt" copies) are
    # skipped instead of crashing the integer parse.
    name_pattern = re.compile(r'checkpoint_epoch(\d+)\.pt$')
    epoch_by_file = {}
    for name in os.listdir(directory):
        match = name_pattern.search(name)
        if match:
            epoch_by_file[name] = int(match.group(1))

    if not epoch_by_file:
        return model, optimizer, 0

    latest_checkpoint = max(epoch_by_file, key=epoch_by_file.get)
    return load_checkpoint(model, optimizer, os.path.join(directory, latest_checkpoint))

def train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes, device, num_epochs=5, resume_training=False):
    """Train ``model``, validating after every epoch, and return it with the best weights.

    Args:
        model: Module to train; it is moved to ``device``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: LR scheduler stepped once per epoch, after the train phase.
        dataloaders: Mapping with ``'train'`` and ``'val'`` iterables that
            yield ``(inputs, labels)`` batches.
        dataset_sizes: Mapping with the number of samples for ``'train'`` and
            ``'val'``; used to average loss and accuracy.
        device: Device the model and every batch are moved to.
        num_epochs: Total number of epochs; epochs are indexed ``0..num_epochs-1``.
        resume_training: When True, ``load_recent_checkpoint`` is called (with
            its default directory) and training continues after the epochs
            recorded in the newest checkpoint.

    Returns:
        ``model`` loaded with the weights of the epoch that reached the
        highest validation accuracy during this call.

    Notes:
        Checkpoints are written with the number of *completed* epochs, i.e.
        the index of the next epoch to run, so a resumed run does not repeat
        an epoch that already finished. Only model and optimizer state are
        checkpointed; scheduler state and the best accuracy seen so far are
        not, so both start fresh on every call.
    """
    since = time.time()
    best_acc = 0.0
    start_epoch = 0

    if resume_training:
        model, optimizer, start_epoch = load_recent_checkpoint(model, optimizer)

    model = model.to(device)  # Move model to device here

    # Snapshot only after a checkpoint may have been loaded. Otherwise a
    # resumed run with no improving epoch (e.g. training already complete)
    # would restore the stale pre-checkpoint weights at the end.
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(start_epoch, num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print("-"*10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train(False) is equivalent to eval()

            running_loss = 0.0
            running_corrects = 0  # plain int, so an empty loader still yields a number

            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)  # Move input data to device
                labels = labels.to(device)

                optimizer.zero_grad()

                # Gradients are tracked only while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # The loss is weighted by the batch size before averaging
                # over the dataset size below.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()

            if is_train:
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print("{} Loss: {:.4f} Acc: {:.4f}".format(phase, epoch_loss, epoch_acc))

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                # best_model.pt is written when validation accuracy improves,
                # so the file really holds the best weights of this run.
                save_checkpoint(model, optimizer, epoch + 1, best_model=True)

        print()

        # epoch + 1 == number of completed epochs == where a resume should start.
        save_checkpoint(model, optimizer, epoch + 1)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print("Best Val Acc: {:.4f}".format(best_acc))

    model.load_state_dict(best_model_wts)
    return model


# Call the train_model function
# resume_training=True makes train_model look for the newest checkpoint
# (via load_recent_checkpoint) before it starts iterating over epochs.
model_ft = train_model(model, criterion, optimizer, exp_lr_scheduler, dataloaders, dataset_sizes, device, num_epochs=75, resume_training=True)

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
Epoch 69/74
----------


100%|██████████| 252/252 [02:38<00:00,  1.59it/s]


train Loss: 0.2959 Acc: 0.9577


100%|██████████| 284/284 [00:44<00:00,  6.33it/s]


val Loss: 0.4666 Acc: 0.8419

Epoch 70/74
----------


100%|██████████| 252/252 [02:38<00:00,  1.59it/s]


train Loss: 0.2949 Acc: 0.9600


100%|██████████| 284/284 [00:44<00:00,  6.35it/s]


val Loss: 0.4632 Acc: 0.8415

Epoch 71/74
----------


100%|██████████| 252/252 [02:38<00:00,  1.59it/s]


train Loss: 0.2919 Acc: 0.9624


100%|██████████| 284/284 [00:45<00:00,  6.29it/s]


val Loss: 0.4594 Acc: 0.8444

Epoch 72/74
----------


100%|██████████| 252/252 [02:38<00:00,  1.59it/s]


train Loss: 0.2913 Acc: 0.9608


100%|██████████| 284/284 [00:44<00:00,  6.38it/s]


val Loss: 0.4637 Acc: 0.8412

Epoch 73/74
----------


100%|██████████| 252/252 [02:37<00:00,  1.60it/s]


train Loss: 0.2911 Acc: 0.9620


100%|██████████| 284/284 [00:44<00:00,  6.36it/s]


val Loss: 0.4604 Acc: 0.8444

Epoch 74/74
----------


100%|██████████| 252/252 [02:38<00:00,  1.59it/s]


train Loss: 0.2928 Acc: 0.9585


100%|██████████| 284/284 [00:45<00:00,  6.28it/s]


val Loss: 0.4614 Acc: 0.8454

Training complete in 20m 32s
Best Val Acc: 0.8454
