151 changes: 145 additions & 6 deletions drivers/vfio/vfio_iommu_type1.c
Expand Up @@ -41,6 +41,7 @@
#include <linux/notifier.h>
#include <linux/dma-iommu.h>
#include <linux/irqdomain.h>
#include <linux/vfio_sdmdev.h>

#define DRIVER_VERSION "0.2"
#define DRIVER_AUTHOR "Alex Williamson <alex.williamson@redhat.com>"
Expand Down Expand Up @@ -89,6 +90,8 @@ struct vfio_dma {
};

struct vfio_group {
/* iommu_group of mdev's parent device */
struct iommu_group *parent_group;
struct iommu_group *iommu_group;
struct list_head next;
};
Expand Down Expand Up @@ -1327,6 +1330,109 @@ static bool vfio_iommu_has_sw_msi(struct iommu_group *group, phys_addr_t *base)
return ret;
}

/* return 0 if the device is not sdmdev.
* return 1 if the device is sdmdev, the data will be updated with parent
* device's group.
* return -errno if other error.
*/
static int vfio_sdmdev_type(struct device *dev, void *data)
{
struct iommu_group **group = data;
struct iommu_group *pgroup;
int (*_is_sdmdev)(struct device *dev);
struct device *pdev;
int ret = 1;

/* vfio_sdmdev module is not configurated */
_is_sdmdev = symbol_get(vfio_sdmdev_is_sdmdev);
if (!_is_sdmdev)
return 0;

/* check if it belongs to vfio_sdmdev device */
if (!_is_sdmdev(dev)) {
ret = 0;
goto out;
}

pdev = dev->parent;
pgroup = iommu_group_get(pdev);
if (!pgroup) {
ret = -ENODEV;
goto out;
}

if (group) {
/* check if all parent devices is the same */
if (*group && *group != pgroup)
ret = -ENODEV;
else
*group = pgroup;
}

iommu_group_put(pgroup);

out:
symbol_put(vfio_sdmdev_is_sdmdev);

return ret;
}

/* return 0 or -errno */
static int vfio_sdmdev_bus(struct device *dev, void *data)
{
struct bus_type **bus = data;

if (!dev->bus)
return -ENODEV;

/* ensure all devices has the same bus_type */
if (*bus && *bus != dev->bus)
return -EINVAL;

*bus = dev->bus;
return 0;
}

/* return 0 means it is not sd group, 1 means it is, or -EXXX for error */
static int vfio_iommu_type1_attach_sdgroup(struct vfio_domain *domain,
struct vfio_group *group,
struct iommu_group *iommu_group)
{
int ret;
struct bus_type *pbus = NULL;
struct iommu_group *pgroup = NULL;

ret = iommu_group_for_each_dev(iommu_group, &pgroup,
vfio_sdmdev_type);
if (ret < 0)
goto out;
else if (ret > 0) {
domain->domain = iommu_group_share_domain(pgroup);
if (IS_ERR(domain->domain))
goto out;
ret = iommu_group_for_each_dev(pgroup, &pbus,
vfio_sdmdev_bus);
if (ret < 0)
goto err_with_share_domain;

if (pbus && iommu_capable(pbus, IOMMU_CAP_CACHE_COHERENCY))
domain->prot |= IOMMU_CACHE;

group->parent_group = pgroup;
INIT_LIST_HEAD(&domain->group_list);
list_add(&group->next, &domain->group_list);

return 1;
}

return 0;

err_with_share_domain:
iommu_group_unshare_domain(pgroup);
out:
return ret;
}

static int vfio_iommu_type1_attach_group(void *iommu_data,
struct iommu_group *iommu_group)
{
Expand All @@ -1335,8 +1441,8 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
struct vfio_domain *domain, *d;
struct bus_type *bus = NULL, *mdev_bus;
int ret;
bool resv_msi, msi_remap;
phys_addr_t resv_msi_base;
bool resv_msi = false, msi_remap;
phys_addr_t resv_msi_base = 0;

mutex_lock(&iommu->lock);

Expand Down Expand Up @@ -1373,6 +1479,14 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
if (mdev_bus) {
if ((bus == mdev_bus) && !iommu_present(bus)) {
symbol_put(mdev_bus_type);

ret = vfio_iommu_type1_attach_sdgroup(domain, group,
iommu_group);
if (ret < 0)
goto out_free;
else if (ret > 0)
goto replay_check;

if (!iommu->external_domain) {
INIT_LIST_HEAD(&domain->group_list);
iommu->external_domain = domain;
Expand Down Expand Up @@ -1451,12 +1565,13 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,

vfio_test_domain_fgsp(domain);

replay_check:
/* replay mappings on new domains */
ret = vfio_iommu_replay(iommu, domain);
if (ret)
goto out_detach;

if (resv_msi) {
if (!group->parent_group && resv_msi) {
ret = iommu_get_msi_cookie(domain->domain, resv_msi_base);
if (ret)
goto out_detach;
Expand All @@ -1471,7 +1586,10 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
out_detach:
iommu_detach_group(domain->domain, iommu_group);
out_domain:
iommu_domain_free(domain->domain);
if (group->parent_group)
iommu_group_unshare_domain(group->parent_group);
else
iommu_domain_free(domain->domain);
out_free:
kfree(domain);
kfree(group);
Expand Down Expand Up @@ -1527,12 +1645,25 @@ static void vfio_sanity_check_pfn_list(struct vfio_iommu *iommu)
WARN_ON(iommu->notifier.head);
}

static void vfio_iommu_undo(struct vfio_iommu *iommu,
struct iommu_domain *domain)
{
struct rb_node *n = rb_first(&iommu->dma_list);
struct vfio_dma *dma;

for (; n; n = rb_next(n)) {
dma = rb_entry(n, struct vfio_dma, node);
iommu_unmap(domain, dma->iova, dma->size);
}
}

static void vfio_iommu_type1_detach_group(void *iommu_data,
struct iommu_group *iommu_group)
{
struct vfio_iommu *iommu = iommu_data;
struct vfio_domain *domain;
struct vfio_group *group;
struct iommu_domain *sdomain = NULL;

mutex_lock(&iommu->lock);

Expand Down Expand Up @@ -1560,7 +1691,12 @@ static void vfio_iommu_type1_detach_group(void *iommu_data,
if (!group)
continue;

iommu_detach_group(domain->domain, iommu_group);
if (group->parent_group)
sdomain = iommu_group_unshare_domain(
group->parent_group);
else
iommu_detach_group(domain->domain, iommu_group);

list_del(&group->next);
kfree(group);
/*
Expand All @@ -1577,7 +1713,10 @@ static void vfio_iommu_type1_detach_group(void *iommu_data,
else
vfio_iommu_unmap_unpin_reaccount(iommu);
}
iommu_domain_free(domain->domain);
if (domain->domain != sdomain)
iommu_domain_free(domain->domain);
else
vfio_iommu_undo(iommu, sdomain);
list_del(&domain->next);
kfree(domain);
}
Expand Down
96 changes: 96 additions & 0 deletions include/linux/vfio_sdmdev.h
@@ -0,0 +1,96 @@
/* SPDX-License-Identifier: GPL-2.0+ */
#ifndef __VFIO_SDMDEV_H
#define __VFIO_SDMDEV_H

#include <linux/device.h>
#include <linux/iommu.h>
#include <linux/mdev.h>
#include <linux/vfio.h>
#include <uapi/linux/vfio_sdmdev.h>

struct vfio_sdmdev_queue;
struct vfio_sdmdev;

/* event bit used to mask the hardware irq */
#define VFIO_SDMDEV_EVENT_Q_UPDATE BIT(0) /* irq if queue is updated */

/**
* struct vfio_sdmdev_ops - WD device operations
* @get_queue: get a queue from the device according to algorithm
* @put_queue: free a queue to the device
* @start_queue: put queue into action with current process's pasid.
* @stop_queue: stop queue from running state
* @is_q_updated: check whether the task is finished
* @mask_notify: mask the task irq of queue
* @mmap: mmap addresses of queue to user space
* @reset: reset the WD device
* @reset_queue: reset the queue
* @ioctl: ioctl for user space users of the queue
* @get_available_instances: get numbers of the queue remained
*/
struct vfio_sdmdev_ops {
int (*get_queue)(struct vfio_sdmdev *sdmdev,
struct vfio_sdmdev_queue **q);
void (*put_queue)(struct vfio_sdmdev_queue *q);
int (*start_queue)(struct vfio_sdmdev_queue *q);
void (*stop_queue)(struct vfio_sdmdev_queue *q);
int (*is_q_updated)(struct vfio_sdmdev_queue *q);
void (*mask_notify)(struct vfio_sdmdev_queue *q, int event_mask);
int (*mmap)(struct vfio_sdmdev_queue *q, struct vm_area_struct *vma);
int (*reset)(struct vfio_sdmdev *sdmdev);
int (*reset_queue)(struct vfio_sdmdev_queue *q);
long (*ioctl)(struct vfio_sdmdev_queue *q, unsigned int cmd,
unsigned long arg);
int (*get_available_instances)(struct vfio_sdmdev *sdmdev);
};

struct vfio_sdmdev_queue {
struct mutex mutex;
struct vfio_sdmdev *sdmdev;
__u32 flags;
void *priv;
wait_queue_head_t wait;
struct mdev_device *mdev;
int fd;
int container;
#ifdef CONFIG_IOMMU_SVA
int pasid;
#endif
};

struct vfio_sdmdev {
const char *name;
int status;
atomic_t ref;
const struct vfio_sdmdev_ops *ops;
struct device *dev;
struct device cls_dev;
bool is_vf;
u32 iommu_type;
u32 dma_flag;
void *priv;
int flags;
const char *api_ver;
struct mdev_parent_ops mdev_fops;
};

int vfio_sdmdev_register(struct vfio_sdmdev *sdmdev);
void vfio_sdmdev_unregister(struct vfio_sdmdev *sdmdev);
void vfio_sdmdev_wake_up(struct vfio_sdmdev_queue *q);
int vfio_sdmdev_is_sdmdev(struct device *dev);
struct vfio_sdmdev *vfio_sdmdev_pdev_sdmdev(struct device *dev);
struct vfio_sdmdev *mdev_sdmdev(struct mdev_device *mdev);

extern struct mdev_type_attribute mdev_type_attr_flags;
extern struct mdev_type_attribute mdev_type_attr_name;
extern struct mdev_type_attribute mdev_type_attr_device_api;
extern struct mdev_type_attribute mdev_type_attr_available_instances;
#define VFIO_SDMDEV_DEFAULT_MDEV_TYPE_ATTRS \
&mdev_type_attr_name.attr, \
&mdev_type_attr_device_api.attr, \
&mdev_type_attr_available_instances.attr, \
&mdev_type_attr_flags.attr

#define _VFIO_SDMDEV_REGION(vm_pgoff) (vm_pgoff & 0xf)

#endif
29 changes: 29 additions & 0 deletions include/uapi/linux/vfio_sdmdev.h
@@ -0,0 +1,29 @@
/* SPDX-License-Identifier: GPL-2.0+ */
#ifndef _UAPIVFIO_SDMDEV_H
#define _UAPIVFIO_SDMDEV_H

#include <linux/ioctl.h>

#define VFIO_SDMDEV_CLASS_NAME "sdmdev"

/* Device ATTRs in parent dev SYSFS DIR */
#define VFIO_SDMDEV_PDEV_ATTRS_GRP_NAME "params"

/* Parent device attributes */
#define SDMDEV_IOMMU_TYPE "iommu_type"
#define SDMDEV_DMA_FLAG "dma_flag"

/* Maximum length of algorithm name string */
#define VFIO_SDMDEV_ALG_NAME_SIZE 64

/* the bits used in SDMDEV_DMA_FLAG attributes */
#define VFIO_SDMDEV_DMA_INVALID 0
#define VFIO_SDMDEV_DMA_SINGLE_PROC_MAP 1
#define VFIO_SDMDEV_DMA_MULTI_PROC_MAP 2
#define VFIO_SDMDEV_DMA_SVM 4
#define VFIO_SDMDEV_DMA_SVM_NO_FAULT 8
#define VFIO_SDMDEV_DMA_PHY 16

#define VFIO_SDMDEV_CMD_WAIT _IO('W', 1)
#define VFIO_SDMDEV_CMD_BIND_PASID _IO('W', 2)
#endif