# This notebook checks that everything needed is installed and extracts the data

## Check that all required packages are installed
All required packages are listed in the [requirements.txt](../../requirements.txt) file; the cell below installs them with pip.

In [None]:
!pip3 install -r requirements.txt --quiet

## Check that all packages import without error

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt 
import seaborn as sns  
import numpy as np 
import os 
from pathlib import Path
from PIL import Image
import cpmpy as cp
import torch
from torchvision.io import read_image
import torchvision.transforms as T


## Download and extract dataset

In [None]:
# Download the visual-sudoku archive once (skipped when it is already present) and
# extract it into `data_dir`. `os` comes from the imports cell above.
import requests
import zipfile

data_dir = 'data'
vizsudoku_zip = os.path.join(data_dir, 'visual_sudoku.zip')
# Use the `/raw/` URL: a `/blob/` URL returns GitHub's HTML file viewer rather than
# the archive itself, and the saved page then fails to open as a zip file.
VIZSUDOKU_URL = ('https://github.com/CryoCardiogram/perception_based_solving_lab'
                 '/raw/main/data/visual_sudoku.zip')

# `open(..., 'wb')` below fails if the target folder does not exist yet.
os.makedirs(data_dir, exist_ok=True)

if not os.path.exists(vizsudoku_zip) and not os.path.islink(vizsudoku_zip):
    # Stream into a temporary file and move it into place only after a successful
    # download. A failed or interrupted download must not leave a broken archive at
    # `vizsudoku_zip`, because its mere presence makes later runs skip this step.
    partial_zip = vizsudoku_zip + '.part'
    with requests.get(VIZSUDOKU_URL, stream=True, timeout=60) as response:
        # Stop with an HTTP error instead of saving an error page as the archive.
        response.raise_for_status()
        with open(partial_zip, 'wb') as handle:
            for block in response.iter_content(4096):
                if block:  # skip empty keep-alive chunks rather than stopping early
                    handle.write(block)
    os.replace(partial_zip, vizsudoku_zip)
    print('successfully downloaded data')

print('unzipping data...')
# Extract data
with zipfile.ZipFile(vizsudoku_zip, 'r') as zip_ref:
    zip_ref.extractall(path=data_dir)
    print(f"Extracted {vizsudoku_zip} into folder '{data_dir}'")

## Check data folder
If you see a well-centered picture of a sudoku grid with a handwritten '4' in the lower-left corner, all is fine!

In [None]:
# Load one sample from the extracted dataset: its image (as a tensor) and its label
# array. Only the image is displayed; loading the label just checks that the label
# file is present and readable.
see_torch_img = T.ToPILImage()  # tensor -> PIL converter, used to render the image inline
datadir = Path('data') / 'visual_sudoku'
sample_id = '059'
img_file = datadir / 'img' / (sample_id + '.jpg')
label_file = datadir / 'label' / (sample_id + '.npy')

img = read_image(str(img_file))
label = np.load(label_file).astype(int)

see_torch_img(img)

## Check pytorch and torchvision

If it reports `torch.Size([10, 5])`, all is fine.

In [None]:
# Dummy network: a single linear layer mapping 100 features to 5, followed by a
# softmax over the last dimension. It is applied to a random batch of 10 inputs,
# and the cell displays the shape of the result.
dnn = torch.nn.Sequential(torch.nn.Linear(in_features=100, out_features=5),
                          torch.nn.Softmax(dim=-1))
x = torch.randn(10, 100)
output = dnn(x)
output.shape

## Check CPMpy
A trivial CP problem: if it prints `x = 3`, all is fine.

In [None]:
# Trivial constraint-satisfaction problem: x ranges over {1, 2, 3}, must be greater
# than 1 and different from 2, so the only solution is x = 3.
x = cp.intvar(1, 3, name="x")
constraints = [x > 1, x != 2]
csp = cp.Model(constraints)

found = csp.solve()
if not found:
    print("CSP infeasible")
else:
    print("x =", x.value())