## データセットの実装

In [None]:
import os

from PIL import Image
import torch.utils.data as data
import torchvision.transforms as transforms


class SALICONDataset(data.Dataset):
    """SALICON saliency dataset.

    Each item is ``[img, distribution_target, data_id]``:
    - ``img``: the image resized to (h, w) = (192, 256), converted to a tensor and
      normalized with mean/std 0.5.
    - ``distribution_target``: the saliency map resized to (192, 256), as a tensor.
    - ``data_id``: the sample id read from the split file.
    """

    def __init__(self, root_dataset_dir, val_mode=False):
        """
        Dataset class that reads the SALICON dataset.

        Parameters:
        -----------------
        root_dataset_dir : str
            Path to the directory that contains the ``SALICON`` directory.
        val_mode : bool (default: False)
            If False, read the training split; if True, read the validation split.
        """
        self.root_dataset_dir = root_dataset_dir
        self.imgsets_dir = os.path.join(self.root_dataset_dir, 'SALICON/image_sets')
        self.img_dir = os.path.join(self.root_dataset_dir, 'SALICON/imgs')
        self.distribution_target_dir = os.path.join(self.root_dataset_dir, 'SALICON/algmaps')
        self.img_tail = '.jpg'
        self.distribution_target_tail = '.png'
        self.transform = transforms.Compose(
            [transforms.Resize((192, 256)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
        self.distribution_transform = transforms.Compose([transforms.Resize((192, 256)), transforms.ToTensor()])

        train_or_val = "val" if val_mode else "train"
        imgsets_file = os.path.join(self.imgsets_dir, '{}.txt'.format(train_or_val))

        # The split file lists one data id per line. Close it deterministically and skip
        # blank lines, which would otherwise become samples pointing at ".jpg" / ".png".
        with open(imgsets_file) as f:
            data_ids = [line.strip() for line in f]
        self.files = [
            {
                'img': os.path.join(self.img_dir, '{0}{1}'.format(data_id, self.img_tail)),
                'distribution_target': os.path.join(self.distribution_target_dir,
                                                    '{0}{1}'.format(data_id, self.distribution_target_tail)),
                'data_id': data_id,
            }
            for data_id in data_ids
            if data_id
        ]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        """
        Returns
        -----------
        data : list
            [img, distribution_target, data_id]
            img is a 3-channel tensor; distribution_target is a 1-channel tensor.
        """
        data_file = self.files[index]

        # Force 3 channels for the image and 1 for the map: the Discriminator concatenates
        # them and its first conv expects exactly 4 input channels, so a grayscale image
        # (or an RGB map) would break training. The `with` blocks close the file handles.
        with Image.open(data_file['img']) as img_file:
            img = self.transform(img_file.convert('RGB'))
        with Image.open(data_file['distribution_target']) as target_file:
            distribution_target = self.distribution_transform(target_file.convert('L'))

        return [img, distribution_target, data_file['data_id']]


In [None]:
from typing import Optional
import os
import requests
import zipfile
import glob
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm


class Cat2000Dataset(Dataset):
    """CAT2000 training set as ``(image, saliency_map)`` pairs.

    Images are read from ``<download_path>/trainSet/Stimuli/<category>/*.jpg``. Each image
    is paired with ``<same directory>/Output/<name>_SaliencyMap.jpg``; images without such
    a map are skipped with a warning. Both are loaded as RGB PIL images, and the optional
    ``transform_module`` is applied to each of them.
    """

    def __init__(self, categories: Optional[list[str]] = None, transform_module=None, download_path: str = "",
                 download: bool = True):
        """
        Parameters
        ----------
        categories : list of str, optional
            Category sub-directory names under ``Stimuli`` (glob patterns allowed).
            ``None`` means every category (``"*"``).
        transform_module : callable, optional
            Applied to both the image and the map in ``__getitem__``.
        download_path : str
            Directory the archive is downloaded to and extracted into.
        download : bool
            If True and the dataset directory is missing, download and extract it.
        """
        if categories is None:
            categories = ["*"]
        self.download_path = download_path
        self.dataset_path = os.path.join(self.download_path, "trainSet", "Stimuli")
        self.categories = categories
        self.transform = transform_module

        if download and not self.is_exist_dataset():
            self.download_dataset()

        # Collect (image, map) path pairs
        self.image_map_pairs = self.get_image_map_paths()

    def is_exist_dataset(self):
        """Return True if the extracted ``Stimuli`` directory exists."""
        return os.path.exists(self.dataset_path)

    def download_dataset(self):
        """Download the dataset archive and extract it into ``download_path``.

        Does nothing if the dataset directory already exists. Network and zip errors are
        printed and otherwise ignored (best effort); the dataset is then simply empty.
        """
        zip_path = os.path.join(self.download_path, "trainSet.zip")
        url = "http://saliency.mit.edu/trainSet.zip"

        if os.path.exists(self.dataset_path):
            print(f"Dataset already exists at {self.dataset_path}, skipping download.")
            return

        if self.download_path:
            # open() below fails if the target directory does not exist yet
            os.makedirs(self.download_path, exist_ok=True)

        print(f"Downloading dataset from {url}...")

        try:
            # timeout: a stalled connection raises instead of hanging forever
            with requests.get(url, stream=True, timeout=60) as response:
                response.raise_for_status()  # check for HTTP errors
                total_size = int(response.headers.get('content-length', 0))

                with tqdm(total=total_size, unit='B', unit_scale=True, dynamic_ncols=True) as progress:
                    with open(zip_path, 'wb') as f:
                        # 64 KiB chunks; 128-byte chunks made the loop needlessly slow
                        for chunk in response.iter_content(chunk_size=1 << 16):
                            f.write(chunk)
                            progress.update(len(chunk))

            print("\nDownload completed.")

            # Extract the ZIP archive
            print(f"Unzipping {zip_path}...")
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(self.download_path)
            print(f"Extracted to {self.download_path}.")

        except requests.exceptions.RequestException as e:
            print(f"Error during download: {e}")

        except zipfile.BadZipFile:
            print("Error: Bad zip file.")

    def get_image_map_paths(self):
        """Return ``(image_path, map_path)`` tuples for every image that has a saliency map.

        Paths are sorted per category so that dataset indices are reproducible.
        """
        result = []
        for category in self.categories:
            # glob() returns files in arbitrary order; sort for deterministic indexing
            image_paths = sorted(glob.glob(os.path.join(self.dataset_path, category, "*.jpg")))
            for image_path in image_paths:
                stem, _ = os.path.splitext(os.path.basename(image_path))
                map_name = f"{stem}_SaliencyMap.jpg"
                # Build the map path from the image's real directory. Joining `category`
                # would leave a literal "*" in the path for the default ["*"], and then
                # no map would ever be found.
                map_path = os.path.join(os.path.dirname(image_path), "Output", map_name)

                if os.path.exists(map_path):
                    result.append((image_path, map_path))
                else:
                    print(f"Warning: No corresponding map found for {image_path}")
        return result

    def __len__(self):
        return len(self.image_map_pairs)

    def __getitem__(self, idx: int):
        """Return ``(image, map_image)``, with ``transform`` applied to each when set."""
        image_path, map_path = self.image_map_pairs[idx]
        image, map_image = self.convert_to_tensor(image_path, map_path)

        if self.transform:
            # NOTE(review): the transform is called separately on each, so a random
            # transform would draw different parameters for the image and its map.
            image = self.transform(image)
            map_image = self.transform(map_image)

        return image, map_image

    @staticmethod
    def convert_to_tensor(image_path, map_path):
        """Load the image and its saliency map as RGB PIL images.

        Despite the name, no tensor conversion happens here; that is left to ``transform``.
        The source files are closed before returning.
        """
        with Image.open(image_path) as image_file:
            image = image_file.convert("RGB")
        with Image.open(map_path) as map_file:
            map_image = map_file.convert("RGB")

        return image, map_image

    def __str__(self):
        """Return one line per sample with the image and map sizes (loads every sample)."""
        def _size(item):
            # PIL images expose `size`; tensors/arrays (after a transform) expose `shape`
            return tuple(item.shape) if hasattr(item, "shape") else item.size

        buffer = []
        for index in range(len(self)):
            image, map_image = self[index]
            buffer.append(f"image: {_size(image)}, map: {_size(map_image)}")

        return "\n".join(buffer)


## モデルの実装

## 学習とテストの実装

In [None]:
from torch import nn
import torchvision


class Generator(nn.Module):
    """Saliency-map generator: a VGG16 encoder followed by a conv/upsampling decoder.

    The output has a single channel and ends with ``Sigmoid``, so values lie in (0, 1).
    The decoder has four 2x ``Upsample`` stages.
    """

    def __init__(self, pretrained=True):
        """
        Parameters
        ----------
        pretrained : bool
            Passed to ``torchvision.models.vgg16``; load pretrained encoder weights when True.
        """
        super().__init__()
        # Build VGG16 once and split its feature extractor. Constructing it twice (and,
        # with pretrained=True, loading its weights twice) produced the same split.
        vgg_features = torchvision.models.vgg16(pretrained=pretrained).features
        self.encoder_first = vgg_features[:17]  # part used with frozen weights
        self.encoder_last = vgg_features[17:-1]  # part that is trained (last layer dropped)

        # Actually freeze encoder_first. The optimizer is expected to receive only
        # encoder_last/decoder parameters; freezing also avoids computing gradients here.
        for param in self.encoder_first.parameters():
            param.requires_grad_(False)

        self.decoder = nn.Sequential(
            nn.Conv2d(512, 512, 3, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.LeakyReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.LeakyReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 256, 3, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.LeakyReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.LeakyReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(64, 1, 1, padding=0),
            nn.Sigmoid())

    def forward(self, x):
        """Map an image batch (presumably (B, 3, H, W)) to a (B, 1, H', W') saliency map."""
        x = self.encoder_first(x)
        x = self.encoder_last(x)
        x = self.decoder(x)
        return x


class Discriminator(nn.Module):
    """Scores an (image, saliency map) pair: outputs near 1 mean "real", near 0 "generated".

    The input has 4 channels (image channels + map channel). The first ``Linear`` expects
    ``64 * 32 * 24`` features, which matches a 192x256 input: the 1x1 conv with padding=1
    grows it to 194x258, and three 2x2 max-pools shrink that to 24x32.
    """

    def __init__(self):
        super().__init__()
        # (in, out) = 3x3 conv + LeakyReLU, 'M' = 2x2 max-pool. Layers are appended in the
        # same order as a flat Sequential, so parameter names (main.0, main.2, ...) match.
        conv_plan = [(3, 32), 'M', (32, 64), (64, 64), 'M', (64, 64), (64, 64), 'M']
        # NOTE(review): padding=1 on a 1x1 conv adds a zero border; kept as is.
        layers = [nn.Conv2d(4, 3, 1, padding=1), nn.LeakyReLU(inplace=True)]
        for step in conv_plan:
            if step == 'M':
                layers.append(nn.MaxPool2d(2, stride=2))
            else:
                in_ch, out_ch = step
                layers += [nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.LeakyReLU(inplace=True)]
        self.main = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            nn.Linear(64 * 32 * 24, 100, bias=True), nn.Tanh(),
            nn.Linear(100, 2, bias=True), nn.Tanh(),
            nn.Linear(2, 1, bias=True), nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a (batch, 1) tensor of real/fake probabilities."""
        features = self.main(x)
        flattened = features.view(features.shape[0], -1)
        return self.classifier(flattened)


### 学習の実装

In [None]:
from datetime import datetime

import torch
from torch.autograd import Variable
import numpy as np
from matplotlib import pyplot as plt

# -----------------
# SETTING
root_dataset_dir = ""  # Path to the directory that contains the SALICON dataset
alpha = 0.005  # Weight of the content (BCE) term in the Generator loss; the paper recommends 0.005
epochs = 120
batch_size = 32  # 32 in the paper
# -----------------

# Use the start time in checkpoint file names
start_time_stamp = '{0:%Y%m%d-%H%M%S}'.format(datetime.now())

save_dir = "./log/"
os.makedirs(save_dir, exist_ok=True)

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Data loaders
train_dataset = SALICONDataset(
    root_dataset_dir,
)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4,
                                           pin_memory=True, sampler=None)
val_dataset = SALICONDataset(
    root_dataset_dir,
    val_mode=True
)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=True,
                                         sampler=None)

# Models and loss function
loss_func = torch.nn.BCELoss().to(DEVICE)
generator = Generator().to(DEVICE)
discriminator = Discriminator().to(DEVICE)

# Optimizers (settings from the paper). generator.encoder_first is intentionally not
# passed, so those encoder layers are never updated.
optimizer_G = torch.optim.Adagrad([
    {'params': generator.encoder_last.parameters()},
    {'params': generator.decoder.parameters()}
], lr=0.0001, weight_decay=3 * 0.0001)
optimizer_D = torch.optim.Adagrad(discriminator.parameters(), lr=0.0001, weight_decay=3 * 0.0001)

# Training
for epoch in range(epochs):
    generator.train()
    discriminator.train()
    n_updates = 0  # iteration counter
    n_discriminator_updates = 0
    n_generator_updates = 0
    d_loss_sum = 0
    g_loss_sum = 0

    for i, data in enumerate(train_loader):
        imgs = data[0].to(DEVICE)  # ([batch_size, rgb, h, w])
        real_salmaps = data[1].to(DEVICE)  # ([batch_size, 1, h, w])

        # Discriminator targets: 1 = real pair, 0 = generated pair
        valid = torch.ones(imgs.shape[0], 1, device=DEVICE)
        fake = torch.zeros(imgs.shape[0], 1, device=DEVICE)

        # Alternate Generator and Discriminator updates every iteration
        if n_updates % 2 == 0:
            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()
            gen_salmaps = generator(imgs)

            # 4-channel Discriminator input: image + generated map. It must NOT be detached
            # here, otherwise the adversarial term sends no gradient to the Generator.
            fake_d_input = torch.cat((imgs, gen_salmaps), 1)  # ([batch_size, rgbs, h, w])

            # Generator loss: weighted content (BCE) loss + adversarial loss
            g_loss1 = loss_func(gen_salmaps, real_salmaps)
            g_loss2 = loss_func(discriminator(fake_d_input), valid)
            g_loss = alpha * g_loss1 + g_loss2

            # This also accumulates gradients in the Discriminator's parameters;
            # optimizer_D.zero_grad() clears them before the next Discriminator step.
            g_loss.backward()
            optimizer_G.step()

            g_loss_sum += g_loss.item()
            n_generator_updates += 1

        else:
            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Generate fake maps for the *current* batch (no Generator gradients needed).
            # Reusing the previous iteration's maps paired them with the wrong images and
            # broke on a short final batch.
            with torch.no_grad():
                gen_salmaps = generator(imgs)
            fake_d_input = torch.cat((imgs, gen_salmaps), 1)  # ([batch_size, rgbs, h, w])
            real_d_input = torch.cat((imgs, real_salmaps), 1)  # ([batch_size, rgbs, h, w])

            # Discriminator loss
            real_loss = loss_func(discriminator(real_d_input), valid)
            fake_loss = loss_func(discriminator(fake_d_input), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            d_loss_sum += d_loss.item()
            n_discriminator_updates += 1

        n_updates += 1
        if n_updates % 10 == 0:
            if n_discriminator_updates > 0:
                print(
                    "[%d/%d (%d/%d)] [loss D: %f, G: %f]"
                    % (epoch, epochs - 1, i, len(train_loader), d_loss_sum / n_discriminator_updates,
                       g_loss_sum / n_generator_updates)
                )
            else:
                print(
                    "[%d/%d (%d/%d)] [loss G: %f]"
                    % (epoch, epochs - 1, i, len(train_loader), g_loss_sum / n_generator_updates)
                )

    # Save weights every 5 epochs and after the last epoch
    if ((epoch + 1) % 5 == 0) or (epoch == epochs - 1):
        generator_save_path = '{}.pkl'.format(
            os.path.join(save_dir, "{}_generator_epoch{}".format(start_time_stamp, epoch)))
        discriminator_save_path = '{}.pkl'.format(
            os.path.join(save_dir, "{}_discriminator_epoch{}".format(start_time_stamp, epoch)))
        torch.save(generator.state_dict(), generator_save_path)
        torch.save(discriminator.state_dict(), discriminator_save_path)

    # After each epoch, visualize predictions for a couple of validation images
    generator.eval()
    with torch.no_grad():
        print("validation")
        for i, data in enumerate(val_loader):
            image = data[0].to(DEVICE)
            # Predict on the validation image (not on the last training batch)
            gen_salmap = generator(image)
            gen_salmap_np = gen_salmap[0, 0].cpu().numpy()

            # Undo Normalize([0.5], [0.5]) so imshow receives values in [0, 1]
            image_np = image[0].cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5
            plt.imshow(image_np.clip(0, 1))
            plt.show()
            plt.imshow(gen_salmap_np)
            plt.show()
            if i == 1:
                break


### Test Loop

In [None]:
# Google drive からfastText日本語モデル(vector_neologd.zip)をダウンロードする
import requests

URL = "https://drive.google.com/uc?id=0ByFQ96A4DgSPUm9wVWRLdm5qbmc&export=download"


def request(url, file_id):
    """Request a Google Drive file, following the download-confirmation step if present.

    Parameters
    ----------
    url : str
        Download endpoint.
    file_id : str
        Drive file ID, sent as the ``id`` query parameter.

    Returns
    -------
    requests.Response
        A streaming response for the file (or the first response if no token was found).
    """
    session = requests.Session()

    response = session.get(url, params={'id': file_id}, stream=True)
    # NOTE(review): relies on a `download_warning*` cookie carrying the confirm token;
    # confirm that Google Drive still issues it for large files.
    token = get_confirm_token(response)

    if token:
        params = {'id': file_id, 'confirm': token}
        # Repeat the request against the caller-supplied url, not the global URL
        return session.get(url, params=params, stream=True)

    return response


def get_confirm_token(response):
    """Return the value of the first cookie whose name starts with 'download_warning'.

    Returns None when the response carries no such cookie.
    """
    return next(
        (cookie_value
         for cookie_name, cookie_value in response.cookies.items()
         if cookie_name.startswith('download_warning')),
        None,
    )


def save_response_content(response, destination, chunk_size=32768):
    """Stream the response body into ``destination`` (binary), skipping empty chunks.

    Empty chunks are keep-alive artifacts of a streamed response and carry no data.
    """
    with open(destination, "wb") as out_file:
        out_file.writelines(
            piece for piece in response.iter_content(chunk_size) if piece
        )


if __name__ == "__main__":
    file_id = 'TAKE ID FROM SHAREABLE LINK'
    destination = './data/vector_neologd.zip'  # 保存先パスの指定
    responce = request(file_id, destination)
    save_response_content(responce, destination)
