In [2]:
# pytorch Tutorial

In [2]:
import torch
import numpy as np


In [3]:
torch.__version__

'1.13.1+cpu'

In [11]:
# getting the dataset in 
import requests
import zipfile

# download data
url = 'https://download.pytorch.org/tutorial/hymenoptera_data.zip'

r = requests.get(url, allow_redirects=True)
open('hymenoptera_data.zip', 'wb').write(r.content)

# unzip data
with zipfile.ZipFile('hymenoptera_data.zip', 'r') as zip_ref:
    zip_ref.extractall('.')



## 資料讀取
### torch.utils.data.


In [15]:
# 若要定義自己的數據集，需要繼承 Datasets 抽象類別，以及重新 override __init__()、__getitem__()、__len__()。

from torch.utils.data import Dataset

'''

## template

class myDataset(Dataset):
    def __init__(self):
      # 定義初始化參數
      # 讀取資料集路徑

    def __getitem__(self, index):
      # 讀取每次迭代的資料集中第 idx  資料
      # 進行前處理 (torchvision.Transform 等)
        return 資料和 label

    def __len__(self):
      # 計算資料集總共數量
        return 資料集總數

'''


'\n\n## template\n\nclass myDataset(Dataset):\n    def __init__(self):\n      # 定義初始化參數\n      # 讀取資料集路徑\n\n    def __getitem__(self, index):\n      # 讀取每次迭代的資料集中第 idx  資料\n      # 進行前處理 (torchvision.Transform 等)\n        return 資料和 label\n\n    def __len__(self):\n      # 計算資料集總共數量\n        return 資料集總數\n\n'

In [81]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class ExampleDataset(Dataset):
  def __init__(self):
    self.data = "abcdefghijklmnopqrstuvwxyz"
  
  def __getitem__(self,idx): # if the index is idx, what will be the data?
    return self.data[idx]
  
  def __len__(self): # What is the length of the dataset
    return len(self.data)

dataset1 = ExampleDataset() # create the dataset
dataloader = DataLoader(dataset = dataset1,shuffle = True, batch_size = 1)
for datapoint in dataloader:
  print(datapoint)

['c']
['p']
['a']
['m']
['b']
['o']
['g']
['r']
['q']
['v']
['t']
['f']
['w']
['h']
['z']
['d']
['k']
['s']
['j']
['i']
['n']
['e']
['u']
['l']
['x']
['y']


In [82]:
class ExampleDataset(Dataset):
  def __init__(self):
    self.data = "abcdefghijklmnopqrstuvwxyz"
  
  def __getitem__(self,idx): # if the index is idx, what will be the data?
    if idx >= len(self.data): # if the index >= 26, return upper case letter
      return self.data[idx%26].upper()
    else: # if the index < 26, return lower case, return lower case letter
      return self.data[idx]
  
  def __len__(self): # What is the length of the dataset
    return 4 * len(self.data) # The length is now twice as large

dataset1 = ExampleDataset() # create the dataset
dataloader = DataLoader(dataset = dataset1,shuffle = True,batch_size =26)
for datapoint in dataloader:
  print(datapoint)

['P', 'x', 'E', 'F', 'K', 'M', 'E', 'm', 'o', 'N', 'Q', 'H', 't', 'I', 'J', 'C', 'Q', 'M', 'N', 'Z', 'w', 'R', 'O', 'C', 'N', 'G']
['v', 'r', 'c', 'J', 'n', 'j', 'Z', 'X', 'G', 'Y', 'A', 'D', 'C', 'b', 'G', 'E', 'y', 'W', 'g', 'i', 'W', 'T', 'U', 'V', 'Z', 'H']
['a', 'K', 'R', 'Y', 'U', 'L', 'A', 'R', 'O', 'S', 'T', 'u', 'S', 'V', 'U', 'p', 'z', 'I', 's', 'W', 'X', 'q', 'T', 'B', 'F', 'e']
['l', 'h', 'D', 'K', 'k', 'M', 'd', 'S', 'B', 'I', 'O', 'P', 'A', 'B', 'F', 'X', 'Y', 'J', 'D', 'f', 'L', 'L', 'V', 'H', 'Q', 'P']


## 圖片讀取example -1

In [17]:
from torch.utils.data import Dataset
from PIL import Image
import os

### Dataset

In [66]:
class mydataset(Dataset):

    def __init__(self, root_dir, lable_dir):
        self.root_dir = root_dir
        self.label_dir = lable_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)

    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        
        return img, label

    def __len__(self):
        return len(self.img_path)
    

root_dir = r"hymenoptera_data\train"
label_dir_ants = "ants"
label_dir_bees = "bees"

# 圖片物件實體化
ants_lable_set = mydataset(root_dir, label_dir_ants)
bees_lable_set = mydataset(root_dir, label_dir_bees)

# 數據拼接
training_data = ants_lable_set + bees_lable_set

In [67]:
len(ants_lable_set)

124

In [68]:
len(bees_lable_set)

121

In [64]:
# 讀取圖片內容
img, label = ants_lable_set[29]
img.show()

img, label = bees_lable_set[29]
img.show()


In [14]:
from torchvision.datasets import ImageFolder
image_folder = ImageFolder('./hymenoptera_data/train', transform=None, target_transform=None)
print(image_folder.class_to_idx)

{'ants': 0, 'bees': 1}


## 圖片讀取example -2

In [12]:
from torchvision.datasets import ImageFolder
image_folder = ImageFolder('./dog_cat_data/dataset/training_set', transform=None, target_transform=None)
print(image_folder.class_to_idx)

{'cats': 0, 'dogs': 1}


In [20]:
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms

train_transform = transforms.Compose([
                  transforms.Resize((256, 256)),
                  transforms.ToTensor(),
                  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 使用 torchvision.datasets.ImageFolder 讀取貓狗資料
image_folder = ImageFolder('./dog_cat_data/dataset/training_set', transform=train_transform, target_transform=None)
# 建立 DataLoader，shuffle 為 True 表示會將資料進行打亂
data_loader = DataLoader(dataset = image_folder, batch_size= 1, shuffle= True, num_workers= 4)
# 列印數據
for batch_idx, (data, target) in enumerate(data_loader):
     print("data:", data)
     print("label:", target)

     if batch_idx == 0:
          break


data: tensor([[[[ 0.1597, -0.0116,  0.0912,  ...,  0.5707,  0.3994,  0.2624],
          [ 0.2111, -0.0972,  0.2282,  ...,  0.5022,  0.4679,  0.4508],
          [ 0.2111,  0.2282,  0.2967,  ...,  0.3481,  0.4337,  0.5022],
          ...,
          [ 0.9817,  0.9988,  1.0159,  ...,  0.9474,  0.9646,  0.9817],
          [ 0.9303,  0.9474,  0.9817,  ...,  0.9988,  1.0159,  1.0159],
          [ 0.8789,  0.9132,  0.9474,  ...,  0.9988,  1.0159,  1.0159]],

         [[ 0.0301, -0.1275,  0.0126,  ...,  0.8704,  0.6779,  0.5203],
          [ 0.1001, -0.1975,  0.1702,  ...,  0.7829,  0.7304,  0.6954],
          [ 0.1176,  0.1702,  0.2402,  ...,  0.6078,  0.6779,  0.7479],
          ...,
          [ 0.9405,  0.9405,  0.9755,  ...,  0.8179,  0.8529,  0.8704],
          [ 0.9055,  0.9055,  0.9405,  ...,  0.8354,  0.8529,  0.8529],
          [ 0.8529,  0.8704,  0.8880,  ...,  0.8354,  0.8529,  0.8529]],

         [[ 0.3393,  0.1999,  0.3742,  ...,  1.3851,  1.2108,  1.0539],
          [ 0.4091,  0.1

In [None]:
# 內建資料集下載 CIFAR10
import torchvision
cifar_data = torchvision.datasets.CIFAR10(root=r"C:\Users\xdxd2\Sunny_VS_worksapce\Sunny_python\lee hung yi\ML2021-Spring\Pytorch", train=True, download=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to C:\Users\xdxd2\Sunny_VS_worksapce\Sunny_python\lee hung yi\ML2021-Spring\Pytorch\cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting C:\Users\xdxd2\Sunny_VS_worksapce\Sunny_python\lee hung yi\ML2021-Spring\Pytorch\cifar-10-python.tar.gz to C:\Users\xdxd2\Sunny_VS_worksapce\Sunny_python\lee hung yi\ML2021-Spring\Pytorch


In [49]:
from torch.utils.data import Dataset
import os
import cv2

class my_cat_data(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)

    def __getitem__(self, index):
        img_name = self.img_path[index]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        
        img = cv2.imread(img_item_path)
        # cv2.imshow('Image', img)
        # cv2.waitKey(0)  # 等待用户按下任意键
        # cv2.destroyAllWindows()  # 关闭图像窗口

        label = self.label_dir
        label_ = -1
        if label == "cats":
            label_ = 1
        elif label == "dogs":
            label_ = 0
        else:
            lable_ = 2

        return img, label_

    def __len__(self):

        return len(self.img_path)



In [77]:
# 嘗試讀取圖片
root_dir = r"./dog_cat_data/dataset/training_set/"
label_dir = "cats"

animal_dataset = my_cat_data(root_dir, label_dir)

# # 读取图像文件
image = animal_dataset[0][0]

# 检查图像是否成功加载
if image is not None:
    # 图像加载成功，可以在这里进行处理
    # 例如，显示图像
    cv2.imshow('Image', image)
    cv2.waitKey(0)  # 等待用户按下任意键
    cv2.destroyAllWindows()  # 关闭图像窗口
else:
    # 图像加载失败
    print('无法加载图像')

In [78]:
# 整合數據集

root_dir = r"./dog_cat_data/dataset/training_set/"
cats_label_dir = "cats"
dogs_label_dir = "dogs"

cats_dataset = my_cat_data(root_dir, cats_label_dir)
dogs_dataset = my_cat_data(root_dir, dogs_label_dir)

animals_dataset = cats_dataset + dogs_dataset

